Compare commits

...

18 Commits

Author SHA1 Message Date
jnfrati
f25011f9ca Merge branch 'main' of github.com:netbirdio/netbird into feat/admin-cli 2026-07-03 15:26:55 +02:00
jnfrati
1d5d8ef1b7 Merge branch 'feat/admin-cli' of github.com:netbirdio/netbird into feat/admin-cli 2026-07-03 15:26:41 +02:00
jnfrati
1428970a24 reliability fixes and disconnect-all command 2026-07-03 15:26:13 +02:00
Theodor Midtlien
3aa6c02b93 [client] Fix backoff.Ticker goroutine leak in reconnect guard 2026-07-03 12:23:11 +02:00
Zoltan Papp
f6900fb07c [client] backport enforce a single selected exit node (#6640)
* routemanager: enforce a single selected exit node

Backport of the exit-node exclusivity reconcile from the 0.75.0 line
(upstream commit 966fbec11) onto v0.74.0. Exit nodes are mutually
exclusive, but the RouteSelector stores routes with default-on semantics,
so every available exit node reported as selected at once.

Reconcile exit-node selection on each network map: keep at most one
selected -- the user's persisted pick, else whatever management marks for
auto-apply (SkipAutoApply=false), else none. Never auto-activate an exit
node the map does not request.

Carries over only the manager/routeselector logic and its test; the
desktop-only client/server changes and the BumpNetworksRevision UI-push
feature from the original commit are intentionally excluded.

* routeselector: make exit-node reconciliation atomic

enforceSingleExitNode took the RouteSelector lock three separate times
(IsDeselectAll, then DeselectRoutes, then SelectRoutes), so a concurrent
DeselectAllRoutes could interleave and be silently undone: SelectRoutes on
its deselectAll branch clears the flag and re-selects the preferred exit
node, overriding the user's "all off".

Move the whole reconciliation into a single locked RouteSelector method
(SetExclusiveExitNode) that checks deselectAll inside the critical section,
so a deselect-all either fully precedes the reconcile (left untouched) or
fully follows it (honoured). No interleaving is possible.
2026-07-03 10:31:06 +02:00
Zoltan Papp
4b3dd9103d [client] Fix slow wg operations (#6633)
* [iface] Drop redundant device dump in kernel configure()

wgctrl.ConfigureDevice already returns an error when the interface is
missing, so the preceding wg.Device() existence check is redundant. That
check dumps the entire device (all peers) on every configure() call,
making it O(peers) per call and turning bulk peer insertion into
O(peers^2): inserting N peers one by one re-parsed the whole growing peer
list N times. Removing it keeps each peer write constant-time regardless
of how many peers are already configured.

* [iface] Cache WireGuard stats to collapse per-peer device dumps

Each peer runs a WGWatcher that polls GetStats(), and every call dumps
the whole device, so with N peers the watchers perform O(N) full dumps
per poll cycle (O(N^2) work) while each keeps only its own peer's entry.

Wrap the kernel and userspace configurer GetStats() in a short-TTL cache
with singleflight: the staggered per-peer calls share a single device
dump per window and concurrent misses collapse into one dump. The kernel
and userspace WireGuard APIs have no per-peer stats query (a get always
returns the whole device), so a shared cached snapshot avoids the
repeated full dumps.

* Ignore .claude directory
2026-07-02 20:42:43 +02:00
jnfrati
fe3c14413c Merge branch 'main' of github.com:netbirdio/netbird into feat/admin-cli 2026-07-02 18:18:37 +02:00
Riccardo Manfrin
8e3b284f4b [client] Increase mgmt grpc buff size to 16MB (#6641) 2026-07-02 17:50:18 +02:00
Maycon Santos
21aa933584 [misc] Fix GHCR image push after dockers_v2 migration (#6653) 2026-07-02 17:21:06 +02:00
Misha Bragin
1dfa85a917 [management] Add vLLM e2e test (#6649)
* Add vLLM to Agent Network

* Add vllm e2e test
2026-07-02 15:36:51 +02:00
Maycon Santos
859fe19fff [management] return nil when config is not set (#6642)
* [management] return nil when config is not set

* [management] add relay invariant test and enforce config behavior
2026-07-02 14:55:55 +02:00
Misha Bragin
e40cb294f6 [management] Add vLLM to Agent Network (#6643) 2026-07-02 14:45:24 +02:00
Maycon Santos
e203e0f42a [self-hosted] Remove unused server/proxy image override logic in getting-started.sh (#6636) 2026-07-02 14:20:23 +02:00
Zoltan Papp
167be3a30f [ci] Run privileged client tests natively with sudo on Linux (#6635)
Restore the pre-split native, sudo-based run for the Linux Client / Unit
job: build with the privileged tag and run under sudo, matching the darwin
job. Excludes the dockertest harness (client/testutil/privileged) so it does
not recurse into a container spawn. The Docker privileged job is kept as-is.
2026-07-02 12:15:57 +02:00
Viktor Liu
1d8b5f6e5c [client] Make lazy connections opt-out via NB_LAZY_CONN (#6617) 2026-07-02 10:58:16 +02:00
jnfrati
520370a8b0 Merge branch 'main' of github.com:netbirdio/netbird into feat/admin-cli 2026-06-22 17:17:47 +02:00
jnfrati
b5a16a1898 chore: move token commands under admin CLI 2026-06-04 12:49:48 +02:00
jnfrati
449b5cbb80 feat: add self-hosted admin CLI 2026-06-04 11:41:57 +02:00
70 changed files with 3246 additions and 474 deletions

View File

@@ -158,7 +158,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,CGO_ENABLED' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'

View File

@@ -293,8 +293,11 @@ jobs:
${{ steps.goreleaser.outputs.artifacts }}
JSON
# dockers_v2 artifacts have no top-level goarch field, so match the
# per-platform -amd64 tag suffix instead; it works for both the old
# dockers and the new dockers_v2 image naming.
mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
jq -r '.[] | select(.type == "Docker Image") | .name | select(startswith("ghcr.io/") and endswith("-amd64"))' /tmp/goreleaser-artifacts.json
)
for src in "${src_images[@]}"; do

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.claude
.idea
.run
*.iml

View File

@@ -10,7 +10,7 @@ var (
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
EnvKeyNBLazyConn = lazyconn.EnvLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold

View File

@@ -71,12 +71,14 @@ var (
extraIFaceBlackList []string
anonymizeFlag bool
dnsRouteInterval time.Duration
lazyConnEnabled bool
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
// lazyConnEnabled is the parse target for the deprecated --enable-lazy-connection
// flag. The flag is inert; the value is no longer read (use NB_LAZY_CONN instead).
lazyConnEnabled bool
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -210,7 +212,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "Deprecated: no longer used. Lazy connections are controlled by the server and the NB_LAZY_CONN environment variable.")
_ = upCmd.PersistentFlags().MarkDeprecated(enableLazyConnectionFlag, "no longer used; lazy connections are controlled by the server and the NB_LAZY_CONN environment variable")
}

View File

@@ -479,10 +479,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled
}
return &req
}
@@ -600,9 +596,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableIPv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
@@ -718,9 +711,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}
return &loginRequest, nil
}

View File

@@ -17,12 +17,15 @@ import (
type KernelConfigurer struct {
deviceName string
statsCache *statsCache
}
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
return &KernelConfigurer{
c := &KernelConfigurer{
deviceName: deviceName,
}
c.statsCache = newStatsCache(statsCacheTTL, c.fetchStats)
return c
}
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
@@ -246,12 +249,6 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
}
}()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
if err != nil {
return err
}
return wg.ConfigureDevice(c.deviceName, config)
}
@@ -300,6 +297,14 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
return c.statsCache.get()
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}
func (c *KernelConfigurer) fetchStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
@@ -326,7 +331,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}

View File

@@ -0,0 +1,52 @@
package configurer
import (
"sync"
"time"
"golang.org/x/sync/singleflight"
)
const statsCacheTTL = 1 * time.Second
type statsCache struct {
ttl time.Duration
fetch func() (map[string]WGStats, error)
mu sync.RWMutex
value map[string]WGStats
expireAt time.Time
sf singleflight.Group
}
func newStatsCache(ttl time.Duration, fetch func() (map[string]WGStats, error)) *statsCache {
return &statsCache{ttl: ttl, fetch: fetch}
}
func (c *statsCache) get() (map[string]WGStats, error) {
c.mu.RLock()
if c.value != nil && time.Now().Before(c.expireAt) {
value := c.value
c.mu.RUnlock()
return value, nil
}
c.mu.RUnlock()
value, err, _ := c.sf.Do("stats", func() (interface{}, error) {
res, err := c.fetch()
if err != nil {
return nil, err
}
c.mu.Lock()
c.value = res
c.expireAt = time.Now().Add(c.ttl)
c.mu.Unlock()
return res, nil
})
if err != nil {
return nil, err
}
return value.(map[string]WGStats), nil
}

View File

@@ -0,0 +1,70 @@
package configurer
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestStatsCache_CachesWithinTTL(t *testing.T) {
var calls atomic.Int64
c := newStatsCache(50*time.Millisecond, func() (map[string]WGStats, error) {
calls.Add(1)
return map[string]WGStats{"p": {}}, nil
})
for i := 0; i < 10; i++ {
_, err := c.get()
require.NoError(t, err)
}
require.Equal(t, int64(1), calls.Load(), "within TTL only one underlying fetch")
time.Sleep(60 * time.Millisecond)
_, err := c.get()
require.NoError(t, err)
require.Equal(t, int64(2), calls.Load(), "after TTL expiry a fresh fetch happens")
}
func TestStatsCache_SingleFlight(t *testing.T) {
var calls atomic.Int64
release := make(chan struct{})
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
calls.Add(1)
<-release
return map[string]WGStats{}, nil
})
const n = 50
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
_, _ = c.get()
}()
}
time.Sleep(20 * time.Millisecond)
close(release)
wg.Wait()
require.Equal(t, int64(1), calls.Load(), "concurrent misses collapse into one fetch")
}
func TestStatsCache_ErrorNotCached(t *testing.T) {
var calls atomic.Int64
wantErr := errors.New("dump failed")
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
calls.Add(1)
return nil, wantErr
})
_, err := c.get()
require.ErrorIs(t, err, wantErr)
_, err = c.get()
require.ErrorIs(t, err, wantErr)
require.Equal(t, int64(2), calls.Load(), "errors are not cached; each call retries")
}

View File

@@ -40,6 +40,7 @@ type WGUSPConfigurer struct {
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
statsCache *statsCache
uapiListener net.Listener
}
@@ -50,16 +51,19 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
deviceName: deviceName,
activityRecorder: activityRecorder,
}
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
wgCfg.startUAPI()
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
return wgCfg
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
@@ -348,6 +352,10 @@ func (t *WGUSPConfigurer) Close() {
}
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
return t.statsCache.get()
}
func (t *WGUSPConfigurer) fetchStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("ipc get: %w", err)

View File

@@ -322,7 +322,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.DisableIPv6,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,

View File

@@ -16,6 +16,16 @@ import (
"github.com/netbirdio/netbird/route"
)
// lazyForce is the resolved local decision for lazy connections, layered above the
// management feature flag. lazyForceNone defers to management.
type lazyForce int
const (
lazyForceNone lazyForce = iota
lazyForceOn
lazyForceOff
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
@@ -28,7 +38,7 @@ type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
enabledLocally bool
force lazyForce
rosenpassEnabled bool
lazyConnMgr *manager.Manager
@@ -43,28 +53,34 @@ func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerSto
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
force: resolveLazyForce(engineConfig.LazyConnection),
rosenpassEnabled: engineConfig.RosenpassEnabled,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
// Start initializes the connection manager. It starts the lazy connection manager when a
// local override forces it on; with no local override it waits for the management feature flag.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
switch e.force {
case lazyForceOff:
log.Infof("lazy connection manager is disabled by local override (%s or MDM policy)", lazyconn.EnvLazyConn)
e.statusRecorder.UpdateLazyConnection(false)
return
case lazyForceNone:
log.Infof("lazy connection manager is managed by the management feature flag")
e.statusRecorder.UpdateLazyConnection(false)
return
}
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
e.statusRecorder.UpdateLazyConnection(false)
return
}
@@ -76,8 +92,8 @@ func (e *ConnMgr) Start(ctx context.Context) {
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
// a local override (NB_LAZY_CONN or local config) takes precedence over management
if e.force != lazyForceNone {
return nil
}
@@ -89,6 +105,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
@@ -98,6 +115,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return e.addPeersToLazyConnManager()
} else {
if e.lazyConnMgr == nil {
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
@@ -309,6 +327,25 @@ func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
}
// resolveLazyForce determines the local override. NB_LAZY_CONN takes precedence; when it
// is unset the MDM policy override (mdmState) applies. Either wins in both directions over
// the management feature flag; StateUnset for both defers to management.
func resolveLazyForce(mdmState lazyconn.State) lazyForce {
state := lazyconn.EnvState()
if state == lazyconn.StateUnset {
state = mdmState
}
switch state {
case lazyconn.StateOn:
return lazyForceOn
case lazyconn.StateOff:
return lazyForceOff
default:
return lazyForceNone
}
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {

View File

@@ -0,0 +1,40 @@
package internal
import (
"os"
"testing"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestResolveLazyForce(t *testing.T) {
tests := []struct {
name string
env string
envSet bool
mdm lazyconn.State
want lazyForce
}{
{name: "env unset, mdm unset -> defer to management", mdm: lazyconn.StateUnset, want: lazyForceNone},
{name: "env on -> force on", env: "on", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOn},
{name: "env off -> force off", env: "off", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOff},
{name: "env unset, mdm on -> force on", mdm: lazyconn.StateOn, want: lazyForceOn},
{name: "env unset, mdm off -> force off", mdm: lazyconn.StateOff, want: lazyForceOff},
{name: "env on beats mdm off", env: "on", envSet: true, mdm: lazyconn.StateOff, want: lazyForceOn},
{name: "env off beats mdm on", env: "off", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOff},
{name: "unrecognized env, mdm on -> mdm wins", env: "auto", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOn},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv(lazyconn.EnvLazyConn, tt.env)
if !tt.envSet {
os.Unsetenv(lazyconn.EnvLazyConn)
}
if got := resolveLazyForce(tt.mdm); got != tt.want {
t.Fatalf("resolveLazyForce(%v) = %v, want %v", tt.mdm, got, tt.want)
}
})
}
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -601,7 +602,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
BlockInbound: config.BlockInbound,
DisableIPv6: config.DisableIPv6,
LazyConnectionEnabled: config.LazyConnectionEnabled,
LazyConnection: lazyconn.ParseState(config.LazyConnection),
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
@@ -675,7 +676,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.BlockLANAccess,
config.BlockInbound,
config.DisableIPv6,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,

View File

@@ -681,7 +681,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("LazyConnection: %q\n", g.internalConfig.LazyConnection))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}

View File

@@ -885,7 +885,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
LazyConnection: "on",
MTU: 1280,
}

View File

@@ -40,6 +40,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -147,7 +148,9 @@ type EngineConfig struct {
BlockInbound bool
DisableIPv6 bool
LazyConnectionEnabled bool
// LazyConnection is the MDM-sourced lazy-connection override; StateUnset defers to
// the env var and management feature flag.
LazyConnection lazyconn.State
MTU uint16
@@ -1130,7 +1133,6 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
@@ -1999,7 +2001,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,

View File

@@ -3,24 +3,57 @@ package lazyconn
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
EnvLazyConn = "NB_LAZY_CONN"
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
)
func IsLazyConnEnabledByEnv() bool {
val := os.Getenv(EnvEnableLazyConn)
if val == "" {
return false
}
enabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
return false
}
return enabled
// State is the tri-state local override for lazy connections read from the environment.
type State int
const (
// StateUnset means no local override; defer to the management feature flag.
StateUnset State = iota
// StateOn forces lazy connections on, overriding management.
StateOn
// StateOff forces lazy connections off, overriding management.
StateOff
)
// EnvState reads NB_LAZY_CONN and returns the local override state.
func EnvState() State {
return ParseState(os.Getenv(EnvLazyConn))
}
// ParseState interprets a lazy-connection override value (from the environment or an MDM
// policy). It accepts the on/off aliases plus any value strconv.ParseBool understands
// (true/false/1/0). An empty or unrecognized value returns StateUnset so that the
// management feature flag remains in control.
func ParseState(raw string) State {
if raw == "" {
return StateUnset
}
normalized := strings.ToLower(strings.TrimSpace(raw))
switch normalized {
case "on":
return StateOn
case "off":
return StateOff
}
enabled, err := strconv.ParseBool(normalized)
if err != nil {
log.Warnf("failed to parse lazy connection value %q (from %s env or MDM policy): %v", raw, EnvLazyConn, err)
return StateUnset
}
if enabled {
return StateOn
}
return StateOff
}

View File

@@ -0,0 +1,45 @@
package lazyconn
import (
"os"
"testing"
)
func TestEnvState(t *testing.T) {
tests := []struct {
value string
set bool
want State
}{
{set: false, want: StateUnset},
{value: "", set: true, want: StateUnset},
{value: "on", set: true, want: StateOn},
{value: "ON", set: true, want: StateOn},
{value: "true", set: true, want: StateOn},
{value: "1", set: true, want: StateOn},
{value: " on ", set: true, want: StateOn},
{value: "off", set: true, want: StateOff},
{value: "OFF", set: true, want: StateOff},
{value: "false", set: true, want: StateOff},
{value: "0", set: true, want: StateOff},
{value: "auto", set: true, want: StateUnset},
{value: "garbage", set: true, want: StateUnset},
}
for _, tt := range tests {
name := tt.value
if !tt.set {
name = "unset"
}
t.Run(name, func(t *testing.T) {
t.Setenv(EnvLazyConn, tt.value)
if !tt.set {
os.Unsetenv(EnvLazyConn)
}
if got := EnvState(); got != tt.want {
t.Fatalf("EnvState() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -85,7 +85,11 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.initialTicker(ctx)
defer ticker.Stop()
defer func() {
// If backoff.Ticker.send is blocked, context.Done will not close the Ticker goroutine.
// We have to explicitly call Stop, even if we use backoff.WithContext.
ticker.Stop()
}()
tickerChannel := ticker.C

View File

@@ -0,0 +1,92 @@
package guard
import (
"context"
"runtime"
"strings"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/ice"
)
func newTestGuard(status connStatusFunc) *Guard {
srw := NewSRWatcher(nil, nil, nil, ice.Config{})
return NewGuard(log.WithField("test", "guard"), status, 50*time.Millisecond, srw)
}
// countBackoffTickerGoroutines returns how many goroutines are currently sitting
// in backoff/v4.(*Ticker).run (a ticker goroutine that has not exited).
func countBackoffTickerGoroutines() int {
buf := make([]byte, 1<<25) // 32MB
n := runtime.Stack(buf, true)
return strings.Count(string(buf[:n]), "backoff/v4.(*Ticker).run")
}
// TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown reproduces a observed
// leak: after a shutdown burst, ticker run/send goroutines stay parked
// forever even though every reconnect loop has exited.
func TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown(t *testing.T) {
before := countBackoffTickerGoroutines()
const peers = 6000
cancels := make([]context.CancelFunc, 0, peers)
var wg sync.WaitGroup
// A status check slower than the tick cadence. This models the real
// isConnectedOnAllWay/callback doing work: while the loop is busy in the
// handler, the ticker fires the next tick and parks in send(), because
// send() never selects on ctx.
slowStatus := func() ConnStatus {
time.Sleep(70 * time.Millisecond)
return ConnStatusConnected
}
for range peers {
g := newTestGuard(slowStatus)
ctx, cancel := context.WithCancel(context.Background())
cancels = append(cancels, cancel)
wg.Add(1)
go func() {
defer wg.Done()
g.Start(ctx, func() {})
}()
// Force the live ticker to be a newReconnectTicker.
g.SetRelayedConnDisconnected()
}
// Let the replacement tickers get past their 800ms initial interval, so
// many are parked in send() waiting on the (slow) consumer when we tear
// everything down.
time.Sleep(1500 * time.Millisecond)
// Shutdown burst: cancel every peer at once, like engine teardown.
for _, c := range cancels {
c()
}
// Every reconnect loop must return
waitCh := make(chan struct{})
go func() { wg.Wait(); close(waitCh) }()
select {
case <-waitCh:
case <-time.After(30 * time.Second):
t.Fatal("not all reconnect loops returned after ctx cancel")
}
// Give any correctly-stopped ticker goroutines time to unwind.
for range 50 {
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
}
leaked := countBackoffTickerGoroutines() - before
t.Logf("backoff Ticker.run goroutines still parked after teardown of %d peers: %d", peers, leaked)
if leaked > 0 {
t.Errorf("LEAK: %d backoff ticker goroutines parked after all reconnect loops exited "+
"(defer ticker.Stop() stops the initial ticker, not the live replacement)", leaked)
}
}

View File

@@ -101,8 +101,6 @@ type ConfigInput struct {
DNSLabels domain.List
LazyConnectionEnabled *bool
MTU *uint16
}
@@ -180,7 +178,9 @@ type Config struct {
ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
// LazyConnection is the MDM-managed lazy-connection override ("on"/"off"/"").
// Runtime-only: re-derived from MDM policy on each load, never persisted.
LazyConnection string `json:"-"`
MTU uint16
@@ -632,12 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
if input.MTU != nil && *input.MTU != config.MTU {
log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU)
config.MTU = *input.MTU
@@ -728,6 +722,15 @@ func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
}
}
if v, ok := policy.GetBool(mdm.KeyLazyConnection); ok {
state := "off"
if v {
state = "on"
}
config.LazyConnection = state
logApplied(mdm.KeyLazyConnection, state)
}
}
// parseURL parses and validates the URL for the named service. The URL

View File

@@ -130,6 +130,37 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
}
func TestApply_MDMLazyConnection(t *testing.T) {
cases := []struct {
name string
raw any
want string
}{
{"native true", true, "on"},
{"native false", false, "off"},
{"string on", "on", "on"},
{"string off", "off", "off"},
{"string yes", "yes", "on"},
{"string no", "no", "off"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyLazyConnection: c.raw,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, c.want, cfg.LazyConnection)
assert.True(t, cfg.Policy().HasKey(mdm.KeyLazyConnection))
})
}
}
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
const maskSentinel = "**********"

View File

@@ -0,0 +1,191 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func newExitNodeTestManager() *DefaultManager {
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
}
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
return &route.Route{
NetID: route.NetID(netID),
Network: netip.MustParsePrefix("0.0.0.0/0"),
Peer: peer,
SkipAutoApply: skipAutoApply,
}
}
func TestPickPreferredExitNode(t *testing.T) {
tests := []struct {
name string
info exitNodeInfo
want route.NetID
}{
{
name: "persisted user selection wins over management",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b", "c"},
userSelected: []route.NetID{"b"},
selectedByManagement: []route.NetID{"a"},
},
want: "b",
},
{
name: "multiple user-selected self-heal to deterministic min",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b", "c"},
userSelected: []route.NetID{"c", "a"},
},
want: "a",
},
{
name: "explicit opt-out keeps none",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b"},
userDeselected: []route.NetID{"a", "b"},
},
want: "",
},
{
name: "fresh defaults to management auto-apply pick",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b", "c"},
selectedByManagement: []route.NetID{"b"},
},
want: "b",
},
{
name: "no user pick and no management auto-apply selects none",
info: exitNodeInfo{
allIDs: []route.NetID{"c", "a", "b"},
},
want: "",
},
{
name: "user-deselect does not block a management auto-apply sibling",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b"},
userDeselected: []route.NetID{"a"},
selectedByManagement: []route.NetID{"b"},
},
want: "b",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
})
}
}
func TestEnforceSingleExitNode(t *testing.T) {
m := newExitNodeTestManager()
all := []route.NetID{"a", "b", "c"}
m.enforceSingleExitNode("b", all)
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
// Switching the preferred node moves the single selection.
m.enforceSingleExitNode("c", all)
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
// Empty preferred turns every exit node off.
m.enforceSingleExitNode("", all)
for _, id := range all {
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
}
}
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
m := newExitNodeTestManager()
m.routeSelector.DeselectAllRoutes()
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
assert.True(t, m.routeSelector.IsDeselectAll(), "global deselect-all must stay in effect")
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
}
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
m := newExitNodeTestManager()
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
}
m.updateRouteSelectorFromManagement(routes)
// Exactly one exit node (the deterministic first) is selected.
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
// Non-exit routes are left at their default-on state.
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
}
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
m := newExitNodeTestManager()
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
}
all := []route.NetID{"exitA", "exitB"}
// Simulate the state the runtime select path leaves behind: exactly one
// exit node explicitly selected, its sibling deselected.
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
m.updateRouteSelectorFromManagement(routes)
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
}
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
m := newExitNodeTestManager()
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
}
all := []route.NetID{"exitA", "exitB"}
// User deselected exit nodes and selected none.
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
m.updateRouteSelectorFromManagement(routes)
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
}
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
m := newExitNodeTestManager()
// SkipAutoApply=true: management offers the exit nodes but doesn't request
// auto-activation, so none should be selected until the user picks one.
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
}
m.updateRouteSelectorFromManagement(routes)
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
}

View File

@@ -701,7 +701,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
return ips
}
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
// updateRouteSelectorFromManagement reconciles exit-node selection on every
// network map: it keeps at most one exit node selected — the user's persisted
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
// else none. We never auto-activate an exit node the map doesn't request; it
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
// RouteSelector stores routes with default-on semantics, so without this every
// available exit node would report selected at once.
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
@@ -712,13 +718,14 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
return
}
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
info := m.collectExitNodeInfo(clientRoutes)
if len(info.allIDs) == 0 {
return
}
m.updateExitNodeSelections(exitNodeInfo)
m.logExitNodeUpdate(exitNodeInfo)
preferred := pickPreferredExitNode(info)
m.enforceSingleExitNode(preferred, info.allIDs)
m.logExitNodeUpdate(info, preferred)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
@@ -746,6 +753,10 @@ type exitNodeInfo struct {
userDeselected []route.NetID
}
// collectExitNodeInfo categorises the available exit nodes by their persisted
// selection state. It keys on the base (v4) NetID and skips the synthesized
// "-v6" partner, which inherits its base's selection through the RouteSelector
// — counting it separately would double-count the pair.
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
var info exitNodeInfo
@@ -755,6 +766,9 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
netID := haID.NetID()
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
continue
}
info.allIDs = append(info.allIDs, netID)
if m.routeSelector.HasUserSelectionForRoute(netID) {
@@ -791,45 +805,52 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
}
}
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
m.deselectExitNodes(routesToDeselect)
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
// - a persisted user selection wins (deterministic if several survive from
// legacy state, so the set self-heals down to one);
// - otherwise activate only what management marks for auto-apply
// (SkipAutoApply=false); the lexicographically first if it marks several.
//
// Returns "" when neither holds — we never force an arbitrary exit node on. A
// route the map doesn't auto-apply stays off until the user selects it.
// info.userDeselected is informational only: an explicit deselect simply keeps
// that route out of both lists above, so it can't be picked.
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
if len(info.userSelected) > 0 {
return minNetID(info.userSelected)
}
if len(info.selectedByManagement) > 0 {
return minNetID(info.selectedByManagement)
}
return ""
}
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
var routesToDeselect []route.NetID
for _, netID := range allIDs {
if !m.routeSelector.HasUserSelectionForRoute(netID) {
routesToDeselect = append(routesToDeselect, netID)
// enforceSingleExitNode makes preferred the only selected exit node: every other
// available exit node is deselected and preferred (if any) is selected, without
// disturbing non-exit route selections. The whole reconciliation runs under a
// single RouteSelector lock (SetExclusiveExitNode) so a concurrent deselect-all
// cannot interleave and get undone; a global deselect-all is left untouched so
// the user's "all off" stays in effect.
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
m.routeSelector.SetExclusiveExitNode(preferred, allIDs)
}
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
}
// minNetID returns the lexicographically smallest NetID, for a deterministic
// default pick that stays stable across restarts.
func minNetID(ids []route.NetID) route.NetID {
if len(ids) == 0 {
return ""
}
best := ids[0]
for _, id := range ids[1:] {
if id < best {
best = id
}
}
return routesToDeselect
}
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
if len(routesToDeselect) == 0 {
return
}
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
if err != nil {
log.Warnf("Failed to deselect exit nodes: %v", err)
}
}
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
if len(selectedByManagement) == 0 {
return
}
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
if err != nil {
log.Warnf("Failed to select exit nodes: %v", err)
}
}
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
return best
}

View File

@@ -115,7 +115,38 @@ func (rs *RouteSelector) DeselectAllRoutes() {
clear(rs.selectedRoutes)
}
// IsDeselectAll reports whether the user has explicitly deselected all routes.
// SetExclusiveExitNode atomically makes preferred the only selected exit node
// among exitIDs: every other ID in exitIDs is deselected and preferred (when
// non-empty) is selected, all under a single lock. Holding the lock across the
// whole reconciliation prevents a concurrent DeselectAllRoutes from interleaving
// between the deselect and select steps and being silently undone. A global
// deselect-all is left untouched so the user's "all off" stays in effect;
// non-exit routes are never referenced, so their selection is preserved.
func (rs *RouteSelector) SetExclusiveExitNode(preferred route.NetID, exitIDs []route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
for _, id := range exitIDs {
if id == preferred {
continue
}
rs.deselectedRoutes[id] = struct{}{}
delete(rs.selectedRoutes, id)
}
if preferred != "" {
delete(rs.deselectedRoutes, preferred)
rs.selectedRoutes[preferred] = struct{}{}
}
}
// IsDeselectAll reports whether the global "deselect all" flag is set, i.e. the
// user explicitly disabled every route. Callers enforcing per-route invariants
// (e.g. single exit node) should leave the selection untouched when it is.
func (rs *RouteSelector) IsDeselectAll() bool {
rs.mu.RLock()
defer rs.mu.RUnlock()

View File

@@ -38,7 +38,7 @@ func GetEnvKeyNBForceRelay() string {
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
func GetEnvKeyNBLazyConn() string {
return lazyconn.EnvEnableLazyConn
return lazyconn.EnvLazyConn
}
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client

View File

@@ -27,6 +27,7 @@ var allKeys = []string{
KeyWireguardPort,
KeySplitTunnelMode,
KeySplitTunnelApps,
KeyLazyConnection,
}
// canonicalKey maps the lowercase form of a managed-config value name to

View File

@@ -11,6 +11,7 @@ package mdm
import (
"sort"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
@@ -41,6 +42,11 @@ const (
// construction — only one mode can be set at a time.
KeySplitTunnelMode = "splitTunnelMode"
KeySplitTunnelApps = "splitTunnelApps"
// KeyLazyConnection forces the lazy-connection feature on or off, overriding
// the management feature flag. Read as a bool (native bool, or on/off,
// true/false, 1/0, yes/no); absent = defer to management.
KeyLazyConnection = "lazyConnection"
)
// Split-tunnel mode literals (KeySplitTunnelMode values).
@@ -62,12 +68,13 @@ var boolStringLiterals = map[string]bool{
"true": true,
"1": true,
"yes": true,
"on": true,
"false": false,
"0": false,
"no": false,
"off": false,
}
// Policy holds MDM-managed settings read from the platform source. A nil or
// empty Policy means no enforcement is active.
type Policy struct {
@@ -150,7 +157,8 @@ func (p *Policy) GetString(key string) (string, bool) {
}
// GetBool returns the managed value for key coerced to bool, and whether the
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
// key was set. Accepts native bool and string literals (true/false, 1/0,
// yes/no, on/off), case-insensitively and trimmed of surrounding whitespace.
func (p *Policy) GetBool(key string) (bool, bool) {
if p == nil {
return false, false
@@ -163,7 +171,7 @@ func (p *Policy) GetBool(key string) (bool, bool) {
case bool:
return t, true
case string:
b, known := boolStringLiterals[t]
b, known := boolStringLiterals[strings.ToLower(strings.TrimSpace(t))]
return b, known
case int:
return t != 0, true

View File

@@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) {
func TestPolicy_HasKey(t *testing.T) {
p := NewPolicy(map[string]any{
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true,
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true,
})
assert.False(t, p.IsEmpty())
assert.True(t, p.HasKey(KeyManagementURL))
@@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) {
func TestPolicy_GetString(t *testing.T) {
p := NewPolicy(map[string]any{
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true, // wrong type for GetString
KeyPreSharedKey: "", // empty rejected
KeyDisableProfiles: true, // wrong type for GetString
KeyPreSharedKey: "", // empty rejected
})
v, ok := p.GetString(KeyManagementURL)
assert.True(t, ok)
@@ -85,6 +85,11 @@ func TestPolicy_GetBool(t *testing.T) {
{"string 0", "0", false, true},
{"string yes", "yes", true, true},
{"string no", "no", false, true},
{"string on", "on", true, true},
{"string off", "off", false, true},
{"mixed case On", "On", true, true},
{"upper TRUE", "TRUE", true, true},
{"padded yes", " yes ", true, true},
{"int nonzero", 1, true, true},
{"int zero", 0, false, true},
{"int64 nonzero", int64(2), true, true},

View File

@@ -152,7 +152,6 @@ func (s *Server) restartEngineForMDMLocked() error {
s.config = config
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
@@ -305,7 +304,6 @@ func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
msg.DisableFirewall != nil ||
msg.BlockLanAccess != nil ||
msg.DisableNotifications != nil ||
msg.LazyConnectionEnabled != nil ||
msg.BlockInbound != nil ||
msg.DisableIpv6 != nil ||
msg.EnableSSHRoot != nil ||
@@ -348,7 +346,6 @@ func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
msg.BlockLanAccess != nil ||
msg.DisableNotifications != nil ||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
msg.LazyConnectionEnabled != nil ||
msg.BlockInbound != nil
}

View File

@@ -214,7 +214,6 @@ func (s *Server) Start() error {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
if s.sessionWatcher == nil {
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
@@ -463,7 +462,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
config.DisableFirewall = msg.DisableFirewall
config.BlockLANAccess = msg.BlockLanAccess
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
config.DisableIPv6 = msg.DisableIpv6
config.EnableSSHRoot = msg.EnableSSHRoot
@@ -1647,7 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,

View File

@@ -69,43 +69,41 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
disableFirewall := true
blockLANAccess := true
disableNotifications := true
lazyConnectionEnabled := true
blockInbound := true
disableIPv6 := true
mtu := int64(1280)
sshJWTCacheTTL := int32(300)
req := &proto.SetConfigRequest{
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
}
_, err = s.SetConfig(ctx, req)
@@ -140,7 +138,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
require.NotNil(t, cfg.DisableNotifications)
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
require.Equal(t, blockInbound, cfg.BlockInbound)
require.Equal(t, disableIPv6, cfg.DisableIPv6)
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
@@ -164,13 +161,14 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
t.Helper()
metadataFields := map[string]bool{
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
"LazyConnectionEnabled": true, // deprecated: proto field retained for compat, no longer applied
}
expectedFields := map[string]bool{
@@ -190,7 +188,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"DisableIpv6": true,
"NatExternalIPs": true,
@@ -252,7 +249,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"disable-ipv6": "DisableIpv6",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
@@ -269,7 +265,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
fieldsWithoutCLIFlags := map[string]bool{
"DisableNotifications": true, // Only settable via UI
"DisableNotifications": true, // Only settable via UI
"LazyConnectionEnabled": true, // deprecated: no longer settable (managed by server + NB_LAZY_CONN)
}
// Get all SetConfigRequest fields to verify our map is complete.

View File

@@ -74,8 +74,6 @@ type Info struct {
BlockInbound bool
DisableIPv6 bool
LazyConnectionEnabled bool
EnableSSHRoot bool
EnableSSHSFTP bool
EnableSSHLocalPortForwarding bool
@@ -87,7 +85,7 @@ func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6 bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
) {
@@ -105,8 +103,6 @@ func (i *Info) SetFlags(
i.BlockInbound = blockInbound
i.DisableIPv6 = disableIPv6
i.LazyConnectionEnabled = lazyConnectionEnabled
if enableSSHRoot != nil {
i.EnableSSHRoot = *enableSSHRoot
}

View File

@@ -266,7 +266,6 @@ type serviceClient struct {
mAllowSSH *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
mBlockInbound *systray.MenuItem
mNotifications *systray.MenuItem
mAdvancedSettings *systray.MenuItem
@@ -336,11 +335,11 @@ type serviceClient struct {
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
// AND s.connected — both must be true for the menus to be active.
// Zero value (false) matches the Disable() call at AddMenuItem time.
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
eventManager *event.Manager
@@ -1094,7 +1093,6 @@ func (s *serviceClient) onTrayReady() {
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false)
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
s.mSettings.AddSeparator()
@@ -1578,7 +1576,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.RosenpassEnabled = cfg.RosenpassEnabled
config.RosenpassPermissive = cfg.RosenpassPermissive
config.DisableNotifications = &cfg.DisableNotifications
config.LazyConnectionEnabled = cfg.LazyConnectionEnabled
config.BlockInbound = cfg.BlockInbound
config.NetworkMonitor = &cfg.NetworkMonitor
config.DisableDNS = cfg.DisableDns
@@ -1682,12 +1679,6 @@ func (s *serviceClient) loadSettings() {
s.mEnableRosenpass.Uncheck()
}
if cfg.LazyConnectionEnabled {
s.mLazyConnEnabled.Check()
} else {
s.mLazyConnEnabled.Uncheck()
}
if cfg.BlockInbound {
s.mBlockInbound.Check()
} else {
@@ -1833,7 +1824,6 @@ func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
blockInbound := s.mBlockInbound.Checked()
notificationsDisabled := !s.mNotifications.Checked()
@@ -1856,14 +1846,13 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableNotifications: &notificationsDisabled,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled,
BlockInbound: &blockInbound,
DisableNotifications: &notificationsDisabled,
}
if _, err := conn.SetConfig(s.ctx, &req); err != nil {

View File

@@ -4,7 +4,6 @@ const (
allowSSHMenuDescr = "Allow SSH connections"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
blockInboundMenuDescr = "Block inbound connections to the local machine and routed networks"
notificationsMenuDescr = "Enable notifications"
advancedSettingsMenuDescr = "Advanced settings of the application"

View File

@@ -43,8 +43,6 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleAutoConnectClick()
case <-h.client.mEnableRosenpass.ClickedCh:
h.handleRosenpassClick()
case <-h.client.mLazyConnEnabled.ClickedCh:
h.handleLazyConnectionClick()
case <-h.client.mBlockInbound.ClickedCh:
h.handleBlockInboundClick()
case <-h.client.mAdvancedSettings.ClickedCh:
@@ -152,15 +150,6 @@ func (h *eventHandler) handleRosenpassClick() {
}
}
func (h *eventHandler) handleLazyConnectionClick() {
h.toggleCheckbox(h.client.mLazyConnEnabled)
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
}
}
func (h *eventHandler) handleBlockInboundClick() {
h.toggleCheckbox(h.client.mBlockInbound)
if err := h.updateConfigWithErr(); err != nil {

151
combined/cmd/admin.go Normal file
View File

@@ -0,0 +1,151 @@
package cmd
import (
"context"
"fmt"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newAdminCommands creates the admin command tree with combined-specific resource openers.
func newAdminCommands() *cobra.Command {
return admincmd.NewCommands(admincmd.Openers{
Resources: withAdminResources,
Store: withAdminStoreOnly,
IDP: withAdminIDPOnly,
})
}
func newLegacyTokenCommand() *cobra.Command {
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
cmd.Deprecated = "use 'admin token' instead"
return cmd
}
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
mgmtConfig, err := adminManagementConfig(cfg)
if err != nil {
return err
}
managementStore, err := openAdminStore(ctx, cfg)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
eventStore, esErr := openAdminEventStore(ctx, cfg, mgmtConfig)
if esErr != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
}
if eventStore != nil {
defer func() {
if err := eventStore.Close(ctx); err != nil {
log.Debugf("close activity event store: %v", err)
}
}()
}
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
})
}
// withAdminStoreOnly opens only the management store for admin subcommands that do not
// need embedded IdP storage.
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
managementStore, err := openAdminStore(ctx, cfg)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
return fn(ctx, managementStore)
})
}
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
mgmtConfig, err := adminManagementConfig(cfg)
if err != nil {
return err
}
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
return fn(ctx, idpStorage, idpStorageFile)
})
}
func withAdminConfig(cmd *cobra.Command, fn func(ctx context.Context, cfg *CombinedConfig) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
cfg.ApplyAdminDefaults()
applyServerStoreEnv(cfg.Server.Store)
return fn(ctx, cfg)
}
func adminManagementConfig(cfg *CombinedConfig) (*nbconfig.Config, error) {
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return nil, fmt.Errorf("create management config: %w", err)
}
return mgmtConfig, nil
}
func openAdminStore(ctx context.Context, cfg *CombinedConfig) (store.Store, error) {
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
if err != nil {
return nil, fmt.Errorf("create store: %w", err)
}
return managementStore, nil
}
func openAdminEventStore(ctx context.Context, cfg *CombinedConfig, config *nbconfig.Config) (activity.Store, error) {
if config.DataStoreEncryptionKey == "" {
return nil, fmt.Errorf("data store encryption key is not configured")
}
if err := applyActivityStoreEnv(cfg.Server.ActivityStore); err != nil {
return nil, fmt.Errorf("configure activity event store: %w", err)
}
eventStore, err := activitystore.NewSqlStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return nil, fmt.Errorf("open activity event store: %w", err)
}
if eventStore == nil {
return nil, fmt.Errorf("open activity event store: returned nil store")
}
return eventStore, nil
}

View File

@@ -0,0 +1,47 @@
package cmd
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestApplyAdminDefaultsCopiesServerStoreWithoutExposedAddress(t *testing.T) {
cfg := DefaultConfig()
cfg.Server.ExposedAddress = ""
cfg.Server.DataDir = "/srv/netbird"
cfg.Server.Store = StoreConfig{
Engine: "postgres",
DSN: "postgres://user:pass@example.com/netbird",
}
cfg.ApplyAdminDefaults()
require.Equal(t, "/srv/netbird", cfg.Management.DataDir)
require.Equal(t, "postgres", cfg.Management.Store.Engine)
require.Equal(t, cfg.Server.Store.DSN, cfg.Management.Store.DSN)
}
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
eventStore, err := openAdminEventStore(context.Background(), &CombinedConfig{}, &nbconfig.Config{})
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key")
require.Nil(t, eventStore)
}
func TestApplyServerStoreEnv(t *testing.T) {
t.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", "")
t.Setenv("NB_STORE_ENGINE_MYSQL_DSN", "")
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", "")
applyServerStoreEnv(StoreConfig{Engine: "postgres", DSN: "postgres-dsn", File: "store.db"})
require.Equal(t, "postgres-dsn", os.Getenv("NB_STORE_ENGINE_POSTGRES_DSN"))
require.Equal(t, "store.db", os.Getenv("NB_STORE_ENGINE_SQLITE_FILE"))
applyServerStoreEnv(StoreConfig{Engine: "mysql", DSN: "mysql-dsn"})
require.Equal(t, "mysql-dsn", os.Getenv("NB_STORE_ENGINE_MYSQL_DSN"))
}

View File

@@ -6,8 +6,7 @@ import (
"net"
"net/netip"
"os"
"path"
"path/filepath"
filePath "path/filepath"
"strings"
"time"
@@ -299,6 +298,19 @@ func (c *CombinedConfig) ApplySimplifiedDefaults() {
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// ApplyAdminDefaults applies the management settings needed by admin commands even
// when the full server config is invalid and ApplySimplifiedDefaults cannot run.
func (c *CombinedConfig) ApplyAdminDefaults() {
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" || c.Server.Store.File != "" || c.Server.Store.DSN != "" {
c.Management.Store = c.Server.Store
}
}
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
@@ -576,11 +588,11 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
authStorageFile = filePath.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
if !filePath.IsAbs(authStorageFile) {
authStorageFile = filePath.Join(mgmt.DataDir, authStorageFile)
}
}
}
@@ -727,7 +739,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
cfg.EmbeddedIdP.Storage.Config.File = filePath.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer

View File

@@ -64,7 +64,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
rootCmd.AddCommand(newTokenCommands())
rootCmd.AddCommand(newAdminCommands())
rootCmd.AddCommand(newLegacyTokenCommand())
}
func RootCmd() *cobra.Command {
@@ -122,6 +123,37 @@ func execute(cmd *cobra.Command, _ []string) error {
}
// initializeConfig loads and validates the configuration, then initializes logging.
func applyServerStoreEnv(storeConfig StoreConfig) {
if dsn := storeConfig.DSN; dsn != "" {
switch strings.ToLower(storeConfig.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := storeConfig.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
}
func applyActivityStoreEnv(storeConfig StoreConfig) error {
if engine := storeConfig.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && storeConfig.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := storeConfig.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := storeConfig.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
return nil
}
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
@@ -137,30 +169,10 @@ func initializeConfig() error {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
applyServerStoreEnv(config.Server.Store)
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
if err := applyActivityStoreEnv(config.Server.ActivityStore); err != nil {
return err
}
log.Infof("Starting combined NetBird server")

View File

@@ -1,63 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newTokenCommands creates the token command tree with combined-specific store opener.
func newTokenCommands() *cobra.Command {
return tokencmd.NewCommands(withTokenStore)
}
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)
s, err := store.NewStore(ctx, engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -0,0 +1,171 @@
//go:build e2e
package agentnetwork
import (
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/e2e/harness"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// TestVLLMProvider proves the proxy supports a self-hosted vLLM backend. vLLM is
// OpenAI-compatible, so it uses the "vllm" catalog entry (KindCustom) and is
// reached over plain HTTP — no TLS anywhere on the path:
//
// client --tunnel--> netbird proxy --http--> vllm (:8000, OpenAI-compatible)
//
// The mock vLLM server answers /v1/chat/completions with an OpenAI-shaped
// completion carrying a non-zero usage block. The test asserts the chat returns
// 200 with the completion, that the request is recorded in the access log by its
// session id, and that vLLM's usage block is metered into a consumption row —
// which together prove request routing, response parsing, and token accounting
// all work for a self-hosted OpenAI-compatible provider.
//
// It needs no external credentials (the mock ignores auth), so it always runs.
func TestVLLMProvider(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
vllm, err := harness.StartVLLM(ctx, srv)
require.NoError(t, err, "start mock vLLM server")
t.Cleanup(func() { _ = vllm.Terminate(context.Background()) })
grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-vllm"})
require.NoError(t, err, "create group")
t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) })
ephemeral := false
sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{
Name: "e2e-vllm-client",
Type: "reusable",
ExpiresIn: 86400,
UsageLimit: 0,
AutoGroups: []string{grp.Id},
Ephemeral: &ephemeral,
})
require.NoError(t, err, "mint setup key")
require.NotEmpty(t, sk.Key, "setup key plaintext")
// vLLM provider pointed at the mock over plain HTTP. The mock ignores auth,
// so a dummy key satisfies the "Bearer ${API_KEY}" template. The served model
// is enumerated so the router dispatches this model string to this provider.
dummyKey := "sk-vllm-e2e"
prov, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{
Name: "vllm",
ProviderId: "vllm",
UpstreamUrl: vllm.URL,
ApiKey: &dummyKey,
Enabled: ptr(true),
BootstrapCluster: ptr(harness.AgentNetworkCluster),
Models: &[]api.AgentNetworkProviderModel{
{Id: harness.VLLMModel, InputPer1k: 0.001, OutputPer1k: 0.002},
},
})
require.NoError(t, err, "create vllm provider")
t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) })
// Token limit far above the handful of tokens this test drives, so it never
// blocks but switches on usage metering — the switch that makes consumption
// rows get recorded.
enabled := true
pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{
Name: "e2e-vllm-allow",
Enabled: &enabled,
SourceGroups: []string{grp.Id},
DestinationProviderIds: []string{prov.Id},
Limits: &api.AgentNetworkPolicyLimits{
TokenLimit: api.AgentNetworkPolicyTokenLimit{
Enabled: true,
GroupCap: 10_000_000,
UserCap: 10_000_000,
WindowSeconds: 60,
},
},
})
require.NoError(t, err, "create policy")
t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) })
settings, err := srv.GetSettings(ctx)
require.NoError(t, err, "read settings")
require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned")
proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-vllm-proxy")
require.NoError(t, err, "mint proxy token")
px, err := harness.StartProxy(ctx, srv, proxyToken)
require.NoError(t, err, "start proxy")
t.Cleanup(func() { _ = px.Terminate(context.Background()) })
cl, err := harness.StartClient(ctx, srv, sk.Key)
require.NoError(t, err, "start client")
t.Cleanup(func() { _ = cl.Terminate(context.Background()) })
require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management")
if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil {
t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background()))
}
proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint)
require.NoError(t, err, "resolve endpoint to proxy IP")
before, _ := srv.ListAccessLogs(ctx)
sessionID := "e2e-session-vllm"
// Retry to absorb tunnel/DNS jitter on the first call.
var code int
var body string
deadline := time.Now().Add(90 * time.Second)
for time.Now().Before(deadline) {
c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, harness.VLLMModel, "Reply with exactly: pong", sessionID)
if cerr == nil {
code, body = c, b
if code == 200 {
break
}
}
time.Sleep(5 * time.Second)
}
require.Equal(t, 200, code,
"chat through the vLLM provider must return 200; body: %s\n=== vllm logs ===\n%s\n=== proxy logs ===\n%s",
body, vllm.Logs(context.Background()), px.Logs(context.Background()))
require.True(t, strings.Contains(body, "chat.completion"),
"body should be an OpenAI-compatible chat completion; got: %s", body)
// The request must surface as an access-log row carrying our session id.
require.Eventually(t, func() bool {
logs, lerr := srv.ListAccessLogs(ctx)
return lerr == nil && logs.TotalRecords > before.TotalRecords
}, 30*time.Second, 2*time.Second, "an access-log row should be ingested for the vLLM provider")
require.Eventually(t, func() bool {
logs, lerr := srv.ListAccessLogs(ctx)
if lerr != nil {
return false
}
for _, r := range logs.Data {
if r.SessionId != nil && *r.SessionId == sessionID {
return true
}
}
return false
}, 30*time.Second, 2*time.Second, "session id %q must be recorded in an access-log row", sessionID)
// vLLM's usage block (prompt_tokens=11, completion_tokens=2) must be parsed
// and metered into a consumption row with positive token counts.
require.Eventually(t, func() bool {
rows, lerr := srv.ListConsumption(ctx)
if lerr != nil {
return false
}
for _, r := range rows {
if r.TokensInput > 0 && r.TokensOutput > 0 {
return true
}
}
return false
}, 60*time.Second, 3*time.Second, "vLLM usage must be metered into a consumption row")
}

113
e2e/harness/vllm.go Normal file
View File

@@ -0,0 +1,113 @@
//go:build e2e
package harness
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/docker/docker/api/types/container"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
const (
vllmImage = "nginx:alpine"
vllmAlias = "vllm"
vllmPort = "8000/tcp"
// VLLMModel is the served model id the mock advertises and echoes back. It
// matches a real small model commonly served by vLLM so the provider's
// enumerated model and the client's request line up.
VLLMModel = "Qwen/Qwen2.5-0.5B-Instruct"
)
// vllmNginxConf emulates a vLLM OpenAI-compatible server over plain HTTP (vLLM's
// default: no TLS, port 8000). It answers /v1/models with a one-model list and
// any chat/completions path with a canned OpenAI-shaped chat completion carrying
// a non-zero usage block, so the proxy's OpenAI parser records real token
// consumption. Running actual vLLM in CI is infeasible (GPU + multi-GB model
// download), so this stands in for the wire contract the proxy depends on.
const vllmNginxConf = `pid /tmp/nginx.pid;
events {}
http {
server {
listen 8000;
location = /v1/models {
default_type application/json;
return 200 '{"object":"list","data":[{"id":"Qwen/Qwen2.5-0.5B-Instruct","object":"model","owned_by":"vllm"}]}';
}
location / {
default_type application/json;
return 200 '{"id":"chatcmpl-e2e-vllm","object":"chat.completion","created":1700000000,"model":"Qwen/Qwen2.5-0.5B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":2,"total_tokens":13}}';
}
}
}
`
// VLLM is a mock vLLM OpenAI-compatible server on the combined server's network,
// reachable at http://vllm:8000. A "vllm" provider points at it to exercise the
// proxy's support for self-hosted OpenAI-compatible backends.
type VLLM struct {
container testcontainers.Container
workDir string
// URL is the upstream URL the vllm provider points at (http://<alias>:8000).
URL string
}
// StartVLLM runs the mock vLLM server on the shared network over plain HTTP.
func StartVLLM(ctx context.Context, c *Combined) (*VLLM, error) {
workDir, err := os.MkdirTemp("/tmp", "nb-e2e-vllm-*")
if err != nil {
return nil, fmt.Errorf("create vllm work dir: %w", err)
}
// Widen so the (non-root worker) nginx container can traverse the bind mount.
if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e config dir
return nil, fmt.Errorf("chmod vllm dir: %w", err)
}
if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(vllmNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config
return nil, fmt.Errorf("write nginx conf: %w", err)
}
req := testcontainers.ContainerRequest{
Image: vllmImage,
ExposedPorts: []string{vllmPort},
Networks: []string{c.network.Name},
NetworkAliases: map[string][]string{c.network.Name: {vllmAlias}},
Cmd: []string{"nginx", "-c", "/conf/nginx.conf", "-g", "daemon off;"},
HostConfigModifier: func(hc *container.HostConfig) {
hc.Binds = append(hc.Binds, workDir+":/conf:ro")
},
WaitingFor: wait.ForListeningPort(vllmPort).WithStartupTimeout(60 * time.Second),
}
ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
_ = os.RemoveAll(workDir)
return nil, fmt.Errorf("start vllm container: %w", err)
}
return &VLLM{container: ctr, workDir: workDir, URL: "http://" + vllmAlias + ":8000"}, nil
}
// Logs returns the vLLM container logs, for diagnostics on failure.
func (v *VLLM) Logs(ctx context.Context) string {
return containerLogs(ctx, v.container)
}
// Terminate stops the vLLM container and cleans its work dir.
func (v *VLLM) Terminate(ctx context.Context) error {
var err error
if v.container != nil {
err = v.container.Terminate(ctx)
}
if v.workDir != "" {
_ = os.RemoveAll(v.workDir)
}
return err
}

View File

@@ -41,7 +41,7 @@ type Config struct {
GRPCAddr string
}
const localConnectorID = "local"
const LocalConnectorID = "local"
// Provider wraps a Dex server
type Provider struct {
@@ -495,18 +495,60 @@ func (p *Provider) Storage() storage.Storage {
return p.storage
}
// SetClientsMFAChain updates the MFAChain field on OAuth2 clients in Dex storage.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func SetClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, mfaChain []string) error {
previousChains := make(map[string][]string, len(clientIDs))
for _, clientID := range clientIDs {
client, err := st.GetClient(ctx, clientID)
if err != nil {
return fmt.Errorf("failed to get client %s before MFA chain update: %w", clientID, err)
}
previousChains[clientID] = cloneMFAChain(client.MFAChain)
}
updatedClientIDs := make([]string, 0, len(clientIDs))
for _, clientID := range clientIDs {
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = cloneMFAChain(mfaChain)
return old, nil
}); err != nil {
if rollbackErr := rollbackClientsMFAChain(ctx, st, updatedClientIDs, previousChains); rollbackErr != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w (also failed to roll back previous MFA chains: %v)", clientID, err, rollbackErr)
}
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
updatedClientIDs = append(updatedClientIDs, clientID)
}
return nil
}
func rollbackClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, previousChains map[string][]string) error {
var rollbackErrs []error
for i := len(clientIDs) - 1; i >= 0; i-- {
clientID := clientIDs[i]
previousChain := cloneMFAChain(previousChains[clientID])
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = previousChain
return old, nil
}); err != nil {
rollbackErrs = append(rollbackErrs, fmt.Errorf("client %s: %w", clientID, err))
}
}
return errors.Join(rollbackErrs...)
}
func cloneMFAChain(chain []string) []string {
if chain == nil {
return nil
}
return append([]string(nil), chain...)
}
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
for _, clientID := range clientIDs {
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
}
return nil
return SetClientsMFAChain(ctx, p.storage, clientIDs, mfaChain)
}
// Handler returns the Dex server as an http.Handler for embedding in another server.
@@ -546,7 +588,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
// This matches the format Dex uses in JWT tokens
encodedID := EncodeDexUserID(userID, localConnectorID)
encodedID := EncodeDexUserID(userID, LocalConnectorID)
return encodedID, nil
}
@@ -625,7 +667,7 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
// local password connector.
func IsLocalUserID(encodedID string) bool {
_, connectorID, err := DecodeDexUserID(encodedID)
return err == nil && connectorID == localConnectorID
return err == nil && connectorID == LocalConnectorID
}
// GetUser returns a user by email

View File

@@ -3,6 +3,8 @@ package dex
import (
"context"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
@@ -11,11 +13,44 @@ import (
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
sqllib "github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type updateFailingStorage struct {
storage.Storage
failClientID string
}
func (s *updateFailingStorage) UpdateClient(ctx context.Context, id string, updater func(storage.Client) (storage.Client, error)) error {
if id == s.failClientID {
return errors.New("forced update failure")
}
return s.Storage.UpdateClient(ctx, id, updater)
}
func TestSetClientsMFAChainRollsBackUpdatedClients(t *testing.T) {
ctx := context.Background()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-1", MFAChain: []string{"old-1"}}))
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-2", MFAChain: []string{"old-2"}}))
err := SetClientsMFAChain(ctx, &updateFailingStorage{Storage: st, failClientID: "client-2"}, []string{"client-1", "client-2"}, []string{"new"})
require.Error(t, err)
require.Contains(t, err.Error(), "failed to update MFA chain on client client-2")
client1, err := st.GetClient(ctx, "client-1")
require.NoError(t, err)
require.Equal(t, []string{"old-1"}, client1.MFAChain)
client2, err := st.GetClient(ctx, "client-2")
require.NoError(t, err)
require.Equal(t, []string{"old-2"}, client2.MFAChain)
}
func TestUserCreationFlow(t *testing.T) {
ctx := context.Background()

View File

@@ -351,11 +351,6 @@ initialize_default_values() {
NETBIRD_STUN_PORT=3478
# Docker images
# Record whether the operator explicitly pinned the server/proxy images via
# env vars, so the agent-network preset can pick its own defaults without
# clobbering an explicit override.
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
# Combined server replaces separate signal, relay, and management containers
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
@@ -415,15 +410,6 @@ apply_agent_network_preset() {
ENABLE_PROXY="true"
ENABLE_CROWDSEC="false"
# Agent-network ships dedicated server/proxy images. Honor an explicit
# env override; otherwise pin the agent-network builds.
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
fi
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
fi
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
else
@@ -570,7 +556,7 @@ start_services_and_show_instructions() {
echo "Creating proxy access token..."
# Use docker exec with bash to run the token command directly
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T netbird-server \
/go/bin/netbird-server token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
/go/bin/netbird-server admin token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
if [[ -z "$PROXY_TOKEN" ]]; then
echo "ERROR: Failed to create proxy token. Check netbird-server logs." > /dev/stderr

177
management/cmd/admin.go Normal file
View File

@@ -0,0 +1,177 @@
package cmd
import (
"context"
"fmt"
"path"
"path/filepath"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var adminDatadir string
// newAdminCommands creates the admin command tree with management-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(admincmd.Openers{
Resources: withAdminResources,
Store: withAdminStoreOnly,
IDP: withAdminIDPOnly,
})
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
return cmd
}
func newLegacyTokenCommand() *cobra.Command {
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
cmd.Deprecated = "use 'admin token' instead"
cmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
return cmd
}
// withAdminResources initializes logging, loads config, opens the management store
// and embedded IdP storage, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
managementStore, err := openAdminStore(ctx, config, datadir)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
eventStore, esErr := openAdminEventStore(ctx, config, datadir)
if esErr != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
}
if eventStore != nil {
defer func() {
if err := eventStore.Close(ctx); err != nil {
log.Debugf("close activity event store: %v", err)
}
}()
}
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
})
}
// withAdminStoreOnly opens only the management store for admin subcommands that do not
// need embedded IdP storage.
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminConfig(cmd, false, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
managementStore, err := openAdminStore(ctx, config, datadir)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
return fn(ctx, managementStore)
})
}
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, _ string) error {
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
return fn(ctx, idpStorage, idpStorageFile)
})
}
func withAdminConfig(cmd *cobra.Command, applyIDPDefaults bool, fn func(ctx context.Context, config *nbconfig.Config, datadir string) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, datadir, err := loadAdminMgmtConfig(ctx, applyIDPDefaults)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
return fn(ctx, config, datadir)
}
func loadAdminMgmtConfig(ctx context.Context, applyIDPDefaults bool) (*nbconfig.Config, string, error) {
config := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(nbconfig.MgmtConfigPath, config); err != nil {
return nil, "", err
}
if applyIDPDefaults {
if err := ApplyEmbeddedIdPConfig(ctx, config); err != nil {
return nil, "", err
}
}
datadir := config.Datadir
applyAdminDatadirOverride(config, &datadir)
return config, datadir, nil
}
func applyAdminDatadirOverride(config *nbconfig.Config, datadir *string) {
if adminDatadir == "" {
return
}
oldDatadir := *datadir
*datadir = adminDatadir
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" && isDefaultIDPStorageFile(config.EmbeddedIdP.Storage.Config.File, oldDatadir) {
config.EmbeddedIdP.Storage.Config.File = filepath.Join(*datadir, "idp.db")
}
}
func isDefaultIDPStorageFile(file, datadir string) bool {
if file == "" {
return true
}
defaultFile := filepath.Join(datadir, "idp.db")
legacyDefaultFile := path.Join(datadir, "idp.db")
legacySlashDefaultFile := path.Join(filepath.ToSlash(datadir), "idp.db")
return filepath.Clean(file) == filepath.Clean(defaultFile) ||
file == legacyDefaultFile ||
filepath.ToSlash(file) == legacySlashDefaultFile
}
func openAdminStore(ctx context.Context, config *nbconfig.Config, datadir string) (store.Store, error) {
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return nil, fmt.Errorf("create store: %w", err)
}
return managementStore, nil
}
func openAdminEventStore(ctx context.Context, config *nbconfig.Config, datadir string) (activity.Store, error) {
if config.DataStoreEncryptionKey == "" {
return nil, fmt.Errorf("data store encryption key is not configured")
}
eventStore, err := activitystore.NewSqlStore(ctx, datadir, config.DataStoreEncryptionKey)
if err != nil {
return nil, fmt.Errorf("open activity event store: %w", err)
}
if eventStore == nil {
return nil, fmt.Errorf("open activity event store: returned nil store")
}
return eventStore, nil
}

View File

@@ -0,0 +1,577 @@
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
// Both the management and combined binaries use these commands, each providing
// their own opener to handle config loading and storage initialization.
package admincmd
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"strings"
"time"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
"github.com/netbirdio/netbird/formatter/hook"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/cmd/proxy"
"github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// Resources contains the storages required by the admin commands.
type Resources struct {
Store store.Store
IDPStorage storage.Storage
IDPStorageFile string
EventStore activity.Store
}
// Opener initializes command resources from the command context and calls fn.
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
// StoreOpener initializes only the management store from the command context and calls fn.
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
// IDPOpener initializes only the embedded IdP storage from the command context and calls fn.
type IDPOpener func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error
// Openers contains the resource openers needed by the admin command tree.
type Openers struct {
Resources Opener
Store StoreOpener
IDP IDPOpener
}
type userSelector struct {
email string
userID string
}
func (s userSelector) normalized() userSelector {
return userSelector{
email: strings.TrimSpace(s.email),
userID: strings.TrimSpace(s.userID),
}
}
func (s userSelector) validate() error {
s = s.normalized()
if (s.email == "") == (s.userID == "") {
return fmt.Errorf("provide exactly one of --email or --user-id")
}
return nil
}
// NewCommands creates the admin command tree with the given resource openers.
func NewCommands(openers Openers) *cobra.Command {
adminCmd := &cobra.Command{
Use: "admin",
Short: "Self-hosted administrator helpers",
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
}
userCmd := &cobra.Command{
Use: "user",
Short: "Manage local embedded IdP users",
}
var passwordSelector userSelector
var password string
var passwordFile string
passwordCmd := &cobra.Command{
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
Aliases: []string{"set-password"},
Short: "Change a local user's password",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
if err := passwordSelector.validate(); err != nil {
return err
}
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
if err != nil {
return err
}
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
return runChangePassword(ctx, idpStorage, cmd.OutOrStdout(), passwordSelector, newPassword, storageFile)
})
},
}
addUserSelectorFlags(passwordCmd, &passwordSelector)
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
var resetSelector userSelector
resetMFACmd := &cobra.Command{
Use: "reset-mfa (--email email | --user-id id)",
Short: "Reset a local user's MFA enrollment",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
if err := resetSelector.validate(); err != nil {
return err
}
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
return runResetMFA(ctx, idpStorage, cmd.OutOrStdout(), resetSelector, storageFile)
})
},
}
addUserSelectorFlags(resetMFACmd, &resetSelector)
userCmd.AddCommand(passwordCmd, resetMFACmd)
mfaCmd := &cobra.Command{
Use: "mfa",
Short: "Manage local MFA for embedded IdP users",
}
enableCmd := &cobra.Command{
Use: "enable",
Short: "Enable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
})
},
}
disableCmd := &cobra.Command{
Use: "disable",
Short: "Disable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
})
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show local MFA status",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
})
},
}
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
adminCmd.AddCommand(userCmd, mfaCmd)
if openers.Store != nil {
adminCmd.AddCommand(tokencmd.NewCommands(tokencmd.StoreOpener(openers.Store)))
adminCmd.AddCommand(proxycmd.NewCommands(proxycmd.StoreOpener(openers.Store)))
}
return adminCmd
}
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
if cfg == nil || !cfg.Enabled {
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
}
yamlConfig, err := cfg.ToYAMLConfig()
if err != nil {
return nil, fmt.Errorf("build embedded IdP config: %w", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
st, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
}
return st, nil
}
// CloseStore closes the management store and logs cleanup errors at debug level.
func CloseStore(ctx context.Context, s store.Store) {
if s == nil {
return
}
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}
// OpenIDPStorage opens embedded IdP storage and returns its sqlite file path when applicable.
func OpenIDPStorage(config *nbconfig.Config) (storage.Storage, string, error) {
if config == nil {
return nil, "", fmt.Errorf("management config is required")
}
idpStorage, err := OpenEmbeddedIDPStorage(config.EmbeddedIdP)
if err != nil {
return nil, "", err
}
return idpStorage, embeddedIDPStorageFile(config), nil
}
func embeddedIDPStorageFile(config *nbconfig.Config) string {
if config.EmbeddedIdP == nil || config.EmbeddedIdP.Storage.Type != "sqlite3" {
return ""
}
return config.EmbeddedIdP.Storage.Config.File
}
// CloseIDPStorage closes embedded IdP storage and logs cleanup errors at debug level.
func CloseIDPStorage(s storage.Storage) {
if s == nil {
return
}
if err := s.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
}
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
if password != "" && passwordFile != "" {
return "", fmt.Errorf("provide only one of --password or --password-file")
}
if passwordFile == "" {
return password, nil
}
var data []byte
var err error
if passwordFile == "-" {
data, err = io.ReadAll(cmd.InOrStdin())
} else {
data, err = os.ReadFile(passwordFile)
}
if err != nil {
return "", fmt.Errorf("read password: %w", err)
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string, idpStorageFile string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
if password == "" {
return fmt.Errorf("password is required")
}
if err := server.ValidatePassword(password); err != nil {
return fmt.Errorf("invalid password: %w", err)
}
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = hash
return old, nil
}); err != nil {
return fmt.Errorf("update password for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
return nil
}
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, idpStorageFile string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
if err != nil {
return err
}
reset := false
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
old.MFASecrets = map[string]*storage.MFASecret{}
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
return old, nil
})
if errors.Is(err, storage.ErrNotFound) {
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
return nil
}
if err != nil {
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
if reset {
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
} else {
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
}
return nil
}
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accountID, settings, err := getSingleAccountSettings(ctx, resources.Store)
if err != nil {
return err
}
oldEnabled := settings.LocalMfaEnabled
newSettings := settings.Copy()
newSettings.LocalMfaEnabled = enabled
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
return err
}
if err := resources.Store.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
if rollbackErr := setIDPClientsMFA(ctx, resources.IDPStorage, oldEnabled); rollbackErr != nil {
return fmt.Errorf("save local MFA account setting: %w (also failed to roll back embedded IdP MFA state: %v)", err, rollbackErr)
}
return fmt.Errorf("save local MFA account setting: %w", err)
}
if err := storeMFAActivity(ctx, resources.EventStore, accountID, enabled); err != nil {
_, _ = fmt.Fprintf(w, "Warning: failed to record audit event: %v\n", err)
}
state := "disabled"
if enabled {
state = "enabled"
}
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
return nil
}
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
_, settings, err := getSingleAccountSettings(ctx, resources.Store)
if err != nil {
return err
}
accountStatus := "disabled"
if settings.LocalMfaEnabled {
accountStatus = "enabled"
}
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
if err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
return nil
}
func getSingleAccountSettings(ctx context.Context, s store.Store) (string, *types.Settings, error) {
count, err := s.GetAccountsCounter(ctx)
if err != nil {
return "", nil, fmt.Errorf("count accounts: %w", err)
}
if count != 1 {
return "", nil, fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", count)
}
accountID, err := s.GetAnyAccountID(ctx)
if err != nil {
return "", nil, fmt.Errorf("get account ID: %w", err)
}
settings, err := s.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return "", nil, fmt.Errorf("get account settings: %w", err)
}
if settings == nil {
settings = &types.Settings{}
}
return accountID, settings, nil
}
func storeMFAActivity(ctx context.Context, eventStore activity.Store, accountID string, enabled bool) error {
if eventStore == nil {
return nil
}
event := activity.AccountLocalMfaDisabled
if enabled {
event = activity.AccountLocalMfaEnabled
}
_, err := eventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: event,
InitiatorID: string(hook.SystemSource),
TargetID: accountID,
AccountID: accountID,
})
if err != nil {
return fmt.Errorf("save local MFA audit event: %w", err)
}
return nil
}
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector, idpStorageFile string) (storage.Password, error) {
selector = selector.normalized()
if err := selector.validate(); err != nil {
return storage.Password{}, err
}
if selector.email != "" {
user, err := idpStorage.GetPassword(ctx, selector.email)
if errors.Is(err, storage.ErrNotFound) {
if empty, listErr := localUsersEmpty(ctx, idpStorage); listErr != nil {
return storage.Password{}, listErr
} else if empty {
return storage.Password{}, noLocalUsersError(idpStorageFile)
}
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
}
if err != nil {
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
}
return user, nil
}
rawUserID := selector.userID
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
rawUserID = decodedUserID
}
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return storage.Password{}, fmt.Errorf("list local users: %w", err)
}
for _, user := range users {
if user.UserID == rawUserID || user.UserID == selector.userID {
return user, nil
}
}
if len(users) == 0 {
return storage.Password{}, noLocalUsersError(idpStorageFile)
}
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
}
func localUsersEmpty(ctx context.Context, idpStorage storage.Storage) (bool, error) {
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return false, fmt.Errorf("list local users: %w", err)
}
return len(users) == 0, nil
}
func noLocalUsersError(idpStorageFile string) error {
location := ""
if idpStorageFile != "" {
location = fmt.Sprintf(" (%s)", idpStorageFile)
}
return fmt.Errorf("no local users exist in the embedded IdP storage%s; the management server may never have started with this config, or --datadir points at the wrong location", location)
}
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
err := idpStorage.DeleteAuthSession(ctx, userID, idp.LocalConnectorID)
if err == nil || errors.Is(err, storage.ErrNotFound) {
return nil
}
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
}
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{idp.DefaultTOTPAuthenticatorID}
}
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
if err := nbdex.SetClientsMFAChain(ctx, idpStorage, clientIDs, mfaChain); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("embedded IdP client not found; start the management server once before toggling MFA: %w", err)
}
return fmt.Errorf("update MFA chain on embedded IdP clients: %w", err)
}
return nil
}
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
enabledCount := 0
for _, clientID := range clientIDs {
client, err := idpStorage.GetClient(ctx, clientID)
if errors.Is(err, storage.ErrNotFound) {
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
}
if err != nil {
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
}
if hasAuthenticator(client.MFAChain, idp.DefaultTOTPAuthenticatorID) {
enabledCount++
}
}
switch enabledCount {
case 0:
return "disabled", nil
case len(clientIDs):
return "enabled", nil
default:
return "partially enabled", nil
}
}
func hasAuthenticator(chain []string, authenticatorID string) bool {
for _, id := range chain {
if id == authenticatorID {
return true
}
}
return false
}

View File

@@ -0,0 +1,250 @@
package admincmd
import (
"bytes"
"context"
"io"
"log/slog"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
mgmtstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
func newTestIDPStorage(t *testing.T) storage.Storage {
t.Helper()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
require.NoError(t, err)
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
Email: "user@example.com",
Username: "User",
UserID: "user-1",
Hash: hash,
}))
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
UserID: "user-1",
ConnectorID: idp.LocalConnectorID,
MFASecrets: map[string]*storage.MFASecret{
idp.DefaultTOTPAuthenticatorID: {
AuthenticatorID: idp.DefaultTOTPAuthenticatorID,
Type: "TOTP",
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
Confirmed: true,
CreatedAt: time.Now(),
},
},
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
"webauthn": {{CredentialID: []byte("credential")}},
},
}))
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
UserID: "user-1",
ConnectorID: idp.LocalConnectorID,
Nonce: "nonce",
}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientCLI, Name: "CLI"}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientDashboard, Name: "Dashboard"}))
return st
}
func TestRunChangePassword(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!", "")
require.NoError(t, err)
require.Contains(t, out.String(), "Password updated")
user, err := st.GetPassword(ctx, "user@example.com")
require.NoError(t, err)
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunChangePasswordValidatesPassword(t *testing.T) {
st := newTestIDPStorage(t)
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short", "")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid password")
}
func TestRunResetMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
encodedUserID := nbdex.EncodeDexUserID("user-1", idp.LocalConnectorID)
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID}, "")
require.NoError(t, err)
require.Contains(t, out.String(), "MFA reset")
identity, err := st.GetUserIdentity(ctx, "user-1", idp.LocalConnectorID)
require.NoError(t, err)
require.Empty(t, identity.MFASecrets)
require.Empty(t, identity.WebAuthnCredentials)
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.MFASecrets = nil
old.WebAuthnCredentials = nil
return old, nil
}))
var out bytes.Buffer
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"}, "")
require.NoError(t, err)
require.Contains(t, out.String(), "No MFA enrollment found")
}
func TestSetIDPClientsMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, setIDPClientsMFA(ctx, st, true))
status, err := idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "enabled", status)
require.NoError(t, setIDPClientsMFA(ctx, st, false))
status, err = idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "disabled", status)
}
func newTestManagementStore(t *testing.T, localMFAEnabled bool) mgmtstore.Store {
t.Helper()
ctx := context.Background()
st, err := mgmtstore.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, st.Close(ctx)) })
require.NoError(t, st.SaveAccount(ctx, &types.Account{
Id: "account-1",
Settings: &types.Settings{LocalMfaEnabled: localMFAEnabled},
}))
return st
}
func TestRunSetMFAEnabledDoesNotSaveWhenIDPUpdateFails(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
require.Error(t, err)
require.Contains(t, err.Error(), "embedded IdP client")
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.False(t, settings.LocalMfaEnabled)
}
func TestRunSetMFAEnabledUpdatesSettingsAfterIDP(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := newTestIDPStorage(t)
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
require.NoError(t, err)
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.True(t, settings.LocalMfaEnabled)
clientStatus, err := idpClientsMFAStatus(ctx, idpStorage)
require.NoError(t, err)
require.Equal(t, "enabled", clientStatus)
}
func TestRunSetMFAEnabledSucceedsWithNilEventStore(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := newTestIDPStorage(t)
var out bytes.Buffer
var err error
require.NotPanics(t, func() {
err = runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage, EventStore: nil}, &out, true)
})
require.NoError(t, err)
require.Contains(t, out.String(), "Local MFA enabled")
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.True(t, settings.LocalMfaEnabled)
}
func TestUserSelectorValidate(t *testing.T) {
require.NoError(t, userSelector{email: " user@example.com "}.validate())
require.NoError(t, userSelector{userID: "user-1"}.validate())
require.Error(t, userSelector{}.validate())
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
}
func TestFindLocalUserNotFound(t *testing.T) {
st := newTestIDPStorage(t)
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "")
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "not found"))
}
func TestFindLocalUserZeroUsersIncludesStoragePath(t *testing.T) {
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "/var/lib/netbird/idp.db")
require.Error(t, err)
require.Contains(t, err.Error(), "no local users exist")
require.Contains(t, err.Error(), "/var/lib/netbird/idp.db")
}
func TestUserCommandValidatesSelectorBeforeOpeningStorage(t *testing.T) {
opened := false
cmd := NewCommands(Openers{
IDP: func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
opened = true
return nil
},
})
cmd.SetArgs([]string{"user", "change-password", "--password", "NewPass1!"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()
require.Error(t, err)
require.Contains(t, err.Error(), "provide exactly one")
require.False(t, opened)
}
func TestResolvePasswordInputFromStdin(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetIn(strings.NewReader("NewPass1!\n"))
password, err := resolvePasswordInput(cmd, "", "-")
require.NoError(t, err)
require.Equal(t, "NewPass1!", password)
}
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
require.Error(t, err)
}

View File

@@ -0,0 +1,80 @@
package cmd
import (
"context"
"path"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
)
func TestApplyAdminDatadirOverrideRelocatesDefaultIDPStorage(t *testing.T) {
oldDatadir := filepath.Join(t.TempDir(), "old")
newDatadir := filepath.Join(t.TempDir(), "new")
for _, defaultFile := range []string{
"",
filepath.Join(oldDatadir, "idp.db"),
path.Join(oldDatadir, "idp.db"),
} {
t.Run(defaultFile, func(t *testing.T) {
cfg := &nbconfig.Config{
EmbeddedIdP: &idp.EmbeddedIdPConfig{
Enabled: true,
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: defaultFile,
},
},
},
}
datadir := oldDatadir
oldAdminDatadir := adminDatadir
adminDatadir = newDatadir
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
applyAdminDatadirOverride(cfg, &datadir)
require.Equal(t, newDatadir, datadir)
require.Equal(t, filepath.Join(newDatadir, "idp.db"), cfg.EmbeddedIdP.Storage.Config.File)
})
}
}
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
eventStore, err := openAdminEventStore(context.Background(), &nbconfig.Config{}, t.TempDir())
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key")
require.Nil(t, eventStore)
}
func TestApplyAdminDatadirOverrideKeepsExplicitIDPStorage(t *testing.T) {
oldDatadir := filepath.Join(t.TempDir(), "old")
newDatadir := filepath.Join(t.TempDir(), "new")
explicitFile := filepath.Join(t.TempDir(), "custom-idp.db")
cfg := &nbconfig.Config{
EmbeddedIdP: &idp.EmbeddedIdPConfig{
Enabled: true,
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: explicitFile,
},
},
},
}
datadir := oldDatadir
oldAdminDatadir := adminDatadir
adminDatadir = newDatadir
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
applyAdminDatadirOverride(cfg, &datadir)
require.Equal(t, newDatadir, datadir)
require.Equal(t, explicitFile, cfg.EmbeddedIdP.Storage.Config.File)
}

View File

@@ -13,6 +13,7 @@ import (
"os"
"os/signal"
"path"
"path/filepath"
"strings"
"syscall"
@@ -209,7 +210,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
cfg.EmbeddedIdP.Storage.Config.File = filepath.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer

View File

@@ -0,0 +1,141 @@
// Package proxycmd provides reusable cobra commands for managing reverse proxy instances.
// Both the management and combined binaries use these commands, each providing
// their own StoreOpener to handle config loading and store initialization.
package proxycmd
import (
"bufio"
"context"
"fmt"
"io"
"strings"
"text/tabwriter"
"github.com/spf13/cobra"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
)
// StoreOpener initializes a store from the command context and calls fn.
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
const disconnectAllConfirmation = "disconnect all proxies"
// NewCommands creates the proxy command tree with the given store opener.
// Returns the parent "proxy" command with the disconnect-all subcommand.
func NewCommands(opener StoreOpener) *cobra.Command {
var dryRun bool
var force bool
proxyCmd := &cobra.Command{
Use: "proxy",
Short: "Manage reverse proxy instances",
Long: "Commands for inspecting and repairing the reverse proxy instances registered with the management server.",
}
disconnectAllCmd := &cobra.Command{
Use: "disconnect-all",
Short: "Force-mark all reverse proxy instances as disconnected",
Long: "Lists all reverse proxy instances and force-marks them as disconnected, regardless of their session state. " +
"Use this to repair stale connection state, e.g. after an unclean management server shutdown. " +
"By default, it asks for manual confirmation before changing state. Use --dry-run to preview without changing state, or --force to skip confirmation. " +
"Run during a maintenance window; affected live proxies may stay hidden until their next heartbeat or reconnect/re-register.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, s store.Store) error {
return runDisconnectAll(ctx, s, cmd.OutOrStdout(), cmd.InOrStdin(), dryRun, force)
})
},
}
disconnectAllCmd.Flags().BoolVar(&dryRun, "dry-run", false, "List reverse proxy instances that would be disconnected without changing state")
disconnectAllCmd.Flags().BoolVar(&force, "force", false, "Skip the confirmation prompt and apply the repair")
proxyCmd.AddCommand(disconnectAllCmd)
return proxyCmd
}
func runDisconnectAll(ctx context.Context, s store.Store, out io.Writer, in io.Reader, dryRun, force bool) error {
proxies, err := s.GetAllProxies(ctx)
if err != nil {
return fmt.Errorf("list proxies: %w", err)
}
if len(proxies) == 0 {
_, _ = fmt.Fprintln(out, "No reverse proxy instances found.")
return nil
}
toDisconnect := 0
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "ID\tCLUSTER\tIP\tACCOUNT\tSTATUS\tLAST SEEN")
_, _ = fmt.Fprintln(w, "--\t-------\t--\t-------\t------\t---------")
for _, p := range proxies {
if p.Status != rpproxy.StatusDisconnected {
toDisconnect++
}
account := "-"
if p.AccountID != nil {
account = *p.AccountID
}
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
p.ID,
p.ClusterAddress,
p.IPAddress,
account,
p.Status,
p.LastSeen.Format("2006-01-02 15:04:05"),
)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("write proxy list: %w", err)
}
if dryRun {
_, _ = fmt.Fprintf(out, "\nDry run: would force-mark %d of %d reverse proxy instance(s) as disconnected.\n", toDisconnect, len(proxies))
return nil
}
if !force {
confirmed, err := confirmDisconnectAll(out, in)
if err != nil {
return err
}
if !confirmed {
_, _ = fmt.Fprintln(out, "Aborted. No reverse proxy instances were changed.")
return nil
}
}
disconnected, err := s.DisconnectAllProxies(ctx)
if err != nil {
return fmt.Errorf("disconnect proxies: %w", err)
}
_, _ = fmt.Fprintf(out, "\nForce-marked %d of %d reverse proxy instance(s) as disconnected.\n", disconnected, len(proxies))
return nil
}
func confirmDisconnectAll(out io.Writer, in io.Reader) (bool, error) {
if in == nil {
in = strings.NewReader("")
}
_, _ = fmt.Fprintln(out, "\nWARNING: This command changes stored reverse proxy state for every non-disconnected instance.")
_, _ = fmt.Fprintln(out, "Run it during a maintenance window; affected live proxies may stay hidden until "+
"their next heartbeat or reconnect/re-register.")
_, _ = fmt.Fprintf(out, "Type %q to continue: ", disconnectAllConfirmation)
scanner := bufio.NewScanner(in)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("read confirmation: %w", err)
}
return false, nil
}
return strings.EqualFold(strings.TrimSpace(scanner.Text()), disconnectAllConfirmation), nil
}

View File

@@ -0,0 +1,180 @@
package proxycmd
import (
"bytes"
"context"
"strings"
"testing"
"time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
)
func newTestStore(t *testing.T) store.Store {
t.Helper()
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
return s
}
func seedProxies(t *testing.T, ctx context.Context, s store.Store) {
t.Helper()
accountID := "account-1"
alreadyDisconnectedAt := time.Now().Add(-time.Hour)
seed := []*rpproxy.Proxy{
{
ID: "proxy-1",
SessionID: "session-1",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.1",
LastSeen: time.Now(),
Status: rpproxy.StatusConnected,
},
{
ID: "proxy-2",
SessionID: "session-2",
ClusterAddress: "cluster-b.example.com",
IPAddress: "10.0.0.2",
AccountID: &accountID,
LastSeen: time.Now(),
Status: rpproxy.StatusConnected,
},
{
ID: "proxy-3",
SessionID: "session-3",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.3",
LastSeen: time.Now().Add(-time.Hour),
Status: rpproxy.StatusDisconnected,
DisconnectedAt: &alreadyDisconnectedAt,
},
}
for _, p := range seed {
require.NoError(t, s.SaveProxy(ctx, p))
}
}
func proxiesByID(t *testing.T, ctx context.Context, s store.Store) map[string]*rpproxy.Proxy {
t.Helper()
proxies, err := s.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, proxies, 3)
byID := make(map[string]*rpproxy.Proxy, len(proxies))
for _, p := range proxies {
byID[p.ID] = p
}
return byID
}
func TestRunDisconnectAllWithConfirmation(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(disconnectAllConfirmation+"\n"), false, false))
output := out.String()
require.Contains(t, output, "proxy-1")
require.Contains(t, output, "proxy-2")
require.Contains(t, output, "proxy-3")
require.Contains(t, output, "cluster-a.example.com")
require.Contains(t, output, "account-1")
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
for _, p := range proxiesByID(t, ctx, s) {
require.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", p.ID)
require.NotNil(t, p.DisconnectedAt, "proxy %s should have a disconnected timestamp", p.ID)
}
}
func TestRunDisconnectAllForceSkipsConfirmation(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, true))
output := out.String()
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
}
func TestRunDisconnectAllAbortLeavesProxiesUnchanged(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader("no\n"), false, false))
output := out.String()
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Aborted. No reverse proxy instances were changed.")
byID := proxiesByID(t, ctx, s)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
}
func TestRunDisconnectAllDryRunLeavesProxiesUnchanged(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), true, false))
output := out.String()
require.Contains(t, output, "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
byID := proxiesByID(t, ctx, s)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
}
func TestNewCommandsDisconnectAllDryRun(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
opened := false
cmd := NewCommands(func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
opened = true
return fn(cmd.Context(), s)
})
var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetErr(&out)
cmd.SetIn(strings.NewReader(""))
cmd.SetArgs([]string{"disconnect-all", "--dry-run"})
require.NoError(t, cmd.ExecuteContext(ctx))
require.True(t, opened)
require.Contains(t, out.String(), "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
}
func TestRunDisconnectAllEmpty(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, false))
require.Contains(t, out.String(), "No reverse proxy instances found.")
}

View File

@@ -83,7 +83,8 @@ func init() {
rootCmd.AddCommand(migrationCmd)
tc := newTokenCommands()
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(tc)
ac := newAdminCommands()
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(ac)
rootCmd.AddCommand(newLegacyTokenCommand())
}

View File

@@ -1,55 +0,0 @@
package cmd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var tokenDatadir string
// newTokenCommands creates the token command tree with management-specific store opener.
func newTokenCommands() *cobra.Command {
cmd := tokencmd.NewCommands(withTokenStore)
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
return cmd
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -627,6 +627,21 @@ var providers = []Provider{
},
Models: []Model{},
},
{
// vLLM is an OpenAI-compatible self-hosted server. It behaves like
// the generic custom entry; it gets its own catalog id purely so it
// surfaces as a named "vLLM" choice in the provider picker.
ID: "vllm",
Kind: KindCustom,
Name: "vLLM",
Description: "Self-hosted vLLM (OpenAI-compatible)",
DefaultHost: "",
AuthHeaderName: "Authorization",
AuthHeaderTemplate: "Bearer ${API_KEY}",
DefaultContentType: "application/json",
BrandColor: "#30A2FF",
Models: []Model{},
},
{
ID: "custom",
Kind: KindCustom,

View File

@@ -47,16 +47,13 @@ func init() {
precomputedDeprecatedRemotePeersConstraint = constraint
}
// toNetbirdConfig converts the server configuration to the wire representation. It returns
// nil when no server config is set (the fan-out network-map path) because clients treat any
// non-nil config as authoritative: a config without a relay section is interpreted as relay
// disabled and wipes the clients' relay URLs.
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings, settings *types.Settings) *proto.NetbirdConfig {
if config == nil {
if settings == nil {
return nil
}
return &proto.NetbirdConfig{
Metrics: &proto.MetricsConfig{
Enabled: settings.MetricsPushEnabled,
},
}
return nil
}
var stuns []*proto.HostConfig

View File

@@ -8,11 +8,13 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/types"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -263,3 +265,39 @@ func TestEncodeSessionExpiresAt(t *testing.T) {
assert.True(t, got.AsTime().Equal(deadline))
})
}
// TestToNetbirdConfig_RelayInvariant guards against the v0.74.0 relay-wipe regression.
// Clients treat any non-nil NetbirdConfig as authoritative and interpret a missing relay
// section as relay disabled, wiping their relay URLs. toNetbirdConfig must therefore
// return nil when no server config is set (the fan-out network-map path) instead of a
// partial config, and a result built from a relay-enabled config must carry the relay
// section.
func TestToNetbirdConfig_RelayInvariant(t *testing.T) {
settings := &types.Settings{MetricsPushEnabled: true}
t.Run("nil server config returns nil config", func(t *testing.T) {
nbCfg := toNetbirdConfig(nil, nil, nil, nil, settings)
assert.Nil(t, nbCfg, "fan-out updates must not carry a partial NetbirdConfig even when settings are present")
})
t.Run("relay-enabled config carries relay section", func(t *testing.T) {
cfg := &nbconfig.Config{
Stuns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "stun:stun.example.com:3478"}},
TURNConfig: &nbconfig.TURNConfig{
Turns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "turn:turn.example.com:3478", Username: "user", Password: "pass"}},
},
Relay: &nbconfig.Relay{Addresses: []string{"rels://relay.example.com:443"}},
Signal: &nbconfig.Host{Proto: nbconfig.HTTP, URI: "signal.example.com:10000"},
}
relayToken := &Token{Payload: "token-payload", Signature: "token-signature"}
nbCfg := toNetbirdConfig(cfg, nil, relayToken, nil, settings)
require.NotNil(t, nbCfg)
require.NotNil(t, nbCfg.Relay, "non-nil NetbirdConfig must include the relay section")
assert.Equal(t, cfg.Relay.Addresses, nbCfg.Relay.Urls, "relay URLs should match the server config")
assert.Equal(t, relayToken.Payload, nbCfg.Relay.TokenPayload, "relay token payload should be set")
assert.Equal(t, relayToken.Signature, nbCfg.Relay.TokenSignature, "relay token signature should be set")
require.NotNil(t, nbCfg.Metrics)
assert.True(t, nbCfg.Metrics.Enabled, "metrics flag should carry the settings value")
})
}

View File

@@ -608,11 +608,11 @@ func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
}
conn.cancel()
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
}
conn.cancel()
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
}

View File

@@ -21,8 +21,11 @@ import (
)
const (
staticClientDashboard = "netbird-dashboard"
staticClientCLI = "netbird-cli"
StaticClientDashboard = "netbird-dashboard"
StaticClientCLI = "netbird-cli"
DefaultTOTPAuthenticatorID = "default-totp"
LocalConnectorID = dex.LocalConnectorID
defaultCLIRedirectURL1 = "http://localhost:53000/"
defaultCLIRedirectURL2 = "http://localhost:54000/"
defaultScopes = "openid profile email groups"
@@ -185,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
EnablePasswordDB: true,
StaticClients: []storage.Client{
{
ID: staticClientDashboard,
ID: StaticClientDashboard,
Name: "NetBird Dashboard",
Public: true,
RedirectURIs: redirectURIs,
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
},
{
ID: staticClientCLI,
ID: StaticClientCLI,
Name: "NetBird CLI",
Public: true,
RedirectURIs: redirectURIs,
@@ -254,13 +257,13 @@ func sanitizePostLogoutRedirectURIs(uris []string) []string {
func configureMFA(cfg *dex.YAMLConfig, sessionMaxLifetime, sessionIdleTimeout string, rememberMe bool, sessionCookieEncryptionKey string) error {
cfg.MFA.Authenticators = []dex.MFAAuthenticator{{
ID: "default-totp",
ID: DefaultTOTPAuthenticatorID,
// Has to be caps otherwise it will fail
Type: "TOTP",
Config: map[string]interface{}{
"issuer": "NetBird",
},
ConnectorTypes: []string{"local"},
ConnectorTypes: []string{LocalConnectorID},
}}
if sessionMaxLifetime == "" {
@@ -736,7 +739,7 @@ func (m *EmbeddedIdPManager) GetDefaultScopes() string {
// GetCLIClientID returns the client ID for CLI authentication.
func (m *EmbeddedIdPManager) GetCLIClientID() string {
return staticClientCLI
return StaticClientCLI
}
// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client.
@@ -775,7 +778,7 @@ func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
func (m *EmbeddedIdPManager) GetClientIDs() []string {
return []string{staticClientDashboard, staticClientCLI}
return []string{StaticClientDashboard, StaticClientCLI}
}
// GetUserIDClaim returns the JWT claim name used for user identification.
@@ -792,11 +795,11 @@ func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
func (m *EmbeddedIdPManager) SetMFAEnabled(ctx context.Context, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{"default-totp"}
mfaChain = []string{DefaultTOTPAuthenticatorID}
}
if err := m.provider.SetClientsMFAChain(ctx, []string{
staticClientCLI,
staticClientDashboard,
StaticClientCLI,
StaticClientDashboard,
}, mfaChain); err != nil {
return fmt.Errorf("failed to set MFA enabled=%v: %w", enabled, err)
}

View File

@@ -331,7 +331,7 @@ func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *tes
var cliRedirectURIs []string
for _, client := range yamlConfig.StaticClients {
if client.ID == staticClientCLI {
if client.ID == StaticClientCLI {
cliRedirectURIs = client.RedirectURIs
break
}

View File

@@ -1048,11 +1048,7 @@ func testUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels {
update := <-channel
assert.NotNil(t, update.Update.NetbirdConfig)
assert.Nil(t, update.Update.NetbirdConfig.Stuns)
assert.Nil(t, update.Update.NetbirdConfig.Turns)
assert.Nil(t, update.Update.NetbirdConfig.Signal)
assert.Nil(t, update.Update.NetbirdConfig.Relay)
assert.Nil(t, update.Update.NetbirdConfig, "fan-out updates must not carry a NetbirdConfig; clients treat a config without relay as relay disabled and wipe their relay URLs")
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}

View File

@@ -6088,6 +6088,37 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
return nil
}
// GetAllProxies returns all reverse proxy instance rows.
func (s *SqlStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
var proxies []*proxy.Proxy
result := s.db.Order("cluster_address, id").Find(&proxies)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get proxies: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get proxies")
}
return proxies, nil
}
// DisconnectAllProxies force-marks every proxy that is not already disconnected
// as disconnected, regardless of session ID. Unlike DisconnectProxy it is not
// session-guarded: it is an administrative repair helper, not part of the
// connection lifecycle. last_seen is left untouched so the stale-proxy reaper
// keeps working off the real last heartbeat. Returns the number of proxies updated.
func (s *SqlStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
result := s.db.
Model(&proxy.Proxy{}).
Where("status != ?", proxy.StatusDisconnected).
Updates(map[string]any{
"status": proxy.StatusDisconnected,
"disconnected_at": time.Now(),
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect all proxies: %v", result.Error)
return 0, status.Errorf(status.Internal, "failed to disconnect all proxies")
}
return result.RowsAffected, nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
now := time.Now()
@@ -6095,7 +6126,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
Update("last_seen", now)
Updates(map[string]any{
"last_seen": now,
"status": proxy.StatusConnected,
"disconnected_at": nil,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)

View File

@@ -0,0 +1,156 @@
package store
import (
"context"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// TestSqlStore_DisconnectAllProxies guards the administrative
// force-disconnect helper:
//
// 1. Every proxy that is not already disconnected is marked
// disconnected regardless of its session ID (unlike
// DisconnectProxy, which is session-guarded).
// 2. Rows that are already disconnected are left untouched, so their
// original disconnected_at is preserved and the returned count
// reflects only the rows that actually changed.
// 3. last_seen is not modified — the stale-proxy reaper keeps working
// off the real last heartbeat.
func TestSqlStore_DisconnectAllProxies(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
lastSeenFresh := time.Now().Add(-30 * time.Second)
lastSeenStale := time.Now().Add(-30 * time.Minute)
oldDisconnectedAt := time.Now().Add(-time.Hour)
accountID := "acct-disconnect"
proxies := []*rpproxy.Proxy{
{
ID: "p-connected-fresh",
SessionID: "sess-1",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.1",
LastSeen: lastSeenFresh,
Status: rpproxy.StatusConnected,
},
{
ID: "p-connected-stale",
SessionID: "sess-2",
ClusterAddress: "cluster-b.example.com",
IPAddress: "10.0.0.2",
AccountID: &accountID,
LastSeen: lastSeenStale,
Status: rpproxy.StatusConnected,
},
{
ID: "p-already-disconnected",
SessionID: "sess-3",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.3",
LastSeen: lastSeenStale,
Status: rpproxy.StatusDisconnected,
DisconnectedAt: &oldDisconnectedAt,
},
}
for _, p := range proxies {
require.NoError(t, store.SaveProxy(ctx, p))
}
all, err := store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 3)
disconnected, err := store.DisconnectAllProxies(ctx)
require.NoError(t, err)
assert.Equal(t, int64(2), disconnected)
all, err = store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 3)
byID := make(map[string]*rpproxy.Proxy, len(all))
for _, p := range all {
byID[p.ID] = p
}
for id, p := range byID {
assert.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", id)
require.NotNil(t, p.DisconnectedAt, "proxy %s should have disconnected_at set", id)
}
// force-marked rows carry a fresh disconnected_at; the untouched row keeps its original one
assert.WithinDuration(t, time.Now(), *byID["p-connected-fresh"].DisconnectedAt, 10*time.Second)
assert.WithinDuration(t, time.Now(), *byID["p-connected-stale"].DisconnectedAt, 10*time.Second)
assert.WithinDuration(t, oldDisconnectedAt, *byID["p-already-disconnected"].DisconnectedAt, time.Second)
// last_seen is preserved so the stale reaper schedule is unaffected
assert.WithinDuration(t, lastSeenFresh, byID["p-connected-fresh"].LastSeen, time.Second)
assert.WithinDuration(t, lastSeenStale, byID["p-connected-stale"].LastSeen, time.Second)
// idempotent: a second run has nothing left to update
disconnected, err = store.DisconnectAllProxies(ctx)
require.NoError(t, err)
assert.Equal(t, int64(0), disconnected)
})
}
func TestSqlStore_UpdateProxyHeartbeatRestoresDisconnectedCurrentSession(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
proxy := &rpproxy.Proxy{
ID: "p-heartbeat",
SessionID: "sess-heartbeat",
ClusterAddress: "cluster-heartbeat.example.com",
IPAddress: "10.0.0.10",
LastSeen: time.Now().Add(-30 * time.Second),
Status: rpproxy.StatusConnected,
}
require.NoError(t, store.SaveProxy(ctx, proxy))
disconnected, err := store.DisconnectAllProxies(ctx)
require.NoError(t, err)
require.Equal(t, int64(1), disconnected)
require.NoError(t, store.UpdateProxyHeartbeat(ctx, &rpproxy.Proxy{ID: proxy.ID, SessionID: proxy.SessionID}))
all, err := store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 1)
assert.Equal(t, rpproxy.StatusConnected, all[0].Status)
assert.Nil(t, all[0].DisconnectedAt)
assert.WithinDuration(t, time.Now(), all[0].LastSeen, 10*time.Second)
addresses, err := store.GetActiveProxyClusterAddresses(ctx)
require.NoError(t, err)
assert.Contains(t, addresses, proxy.ClusterAddress)
})
}
func TestSqlStore_GetAllProxies_Empty(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
all, err := store.GetAllProxies(context.Background())
require.NoError(t, err)
assert.Empty(t, all)
})
}

View File

@@ -323,6 +323,8 @@ type Store interface {
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error)
DisconnectAllProxies(ctx context.Context) (int64, error)
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)

View File

@@ -745,6 +745,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// DisconnectAllProxies mocks base method.
func (m *MockStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectAllProxies", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DisconnectAllProxies indicates an expected call of DisconnectAllProxies.
func (mr *MockStoreMockRecorder) DisconnectAllProxies(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectAllProxies", reflect.TypeOf((*MockStore)(nil).DisconnectAllProxies), ctx)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
@@ -1389,6 +1404,21 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
}
// GetAllProxies mocks base method.
func (m *MockStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllProxies", ctx)
ret0, _ := ret[0].([]*proxy.Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllProxies indicates an expected call of GetAllProxies.
func (mr *MockStoreMockRecorder) GetAllProxies(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxies", reflect.TypeOf((*MockStore)(nil).GetAllProxies), ctx)
}
// GetAllProxyAccessTokens mocks base method.
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
@@ -1521,6 +1551,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetExpiredEphemeralServices mocks base method.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
@@ -1566,6 +1611,21 @@ func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, gr
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetGroupsByIDs mocks base method.
func (m *MockStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types2.Group, error) {
m.ctrl.T.Helper()
@@ -1850,6 +1910,21 @@ func (mr *MockStoreMockRecorder) GetPeerIDByKey(ctx, lockStrength, key interface
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDByKey", reflect.TypeOf((*MockStore)(nil).GetPeerIDByKey), ctx, lockStrength, key)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetPeerIdByLabel mocks base method.
func (m *MockStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID, hostname string) (string, error) {
m.ctrl.T.Helper()
@@ -1925,51 +2000,6 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -1849,12 +1849,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
const minPasswordLength = 8
// validatePassword checks password strength requirements:
// validatePassword checks password strength requirements.
func validatePassword(password string) error {
return ValidatePassword(password)
}
// ValidatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func validatePassword(password string) error {
func ValidatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}

View File

@@ -33,10 +33,15 @@ const ConnectTimeout = 10 * time.Second
const healthCheckTimeout = 5 * time.Second
const (
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size
// for the management client connection. Value is in bytes.
EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE"
// defaultMaxRecvMsgSize is the max gRPC receive message size used for the
// management client connection when EnvMaxRecvMsgSize is unset or invalid.
// It overrides the gRPC library default of 4 MB.
defaultMaxRecvMsgSize = 1024 * 1024 * 16
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
errMsgNoMgmtConnection = "no connection to management"
)
@@ -84,22 +89,22 @@ type ExposeResponse struct {
}
// MaxRecvMsgSize returns the configured max gRPC receive message size from
// the environment, or 0 if unset (which uses the gRPC default of 4 MB).
// the environment, or defaultMaxRecvMsgSize (16 MB) if unset or invalid.
func MaxRecvMsgSize() int {
val := os.Getenv(EnvMaxRecvMsgSize)
if val == "" {
return 0
return defaultMaxRecvMsgSize
}
size, err := strconv.Atoi(val)
if err != nil {
log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err)
return 0
return defaultMaxRecvMsgSize
}
if size <= 0 {
log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size)
return 0
return defaultMaxRecvMsgSize
}
return size
@@ -1030,8 +1035,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
BlockLANAccess: info.BlockLANAccess,
BlockInbound: info.BlockInbound,
DisableIPv6: info.DisableIPv6,
LazyConnectionEnabled: info.LazyConnectionEnabled,
},
Capabilities: peerCapabilities(*info),

View File

@@ -21,11 +21,11 @@ func TestMaxRecvMsgSize(t *testing.T) {
envValue string
expected int
}{
{name: "unset returns 0", envValue: "", expected: 0},
{name: "unset returns default", envValue: "", expected: defaultMaxRecvMsgSize},
{name: "valid value", envValue: "10485760", expected: 10485760},
{name: "non-numeric returns 0", envValue: "abc", expected: 0},
{name: "negative returns 0", envValue: "-1", expected: 0},
{name: "zero returns 0", envValue: "0", expected: 0},
{name: "non-numeric returns default", envValue: "abc", expected: defaultMaxRecvMsgSize},
{name: "negative returns default", envValue: "-1", expected: defaultMaxRecvMsgSize},
{name: "zero returns default", envValue: "0", expected: defaultMaxRecvMsgSize},
}
for _, tt := range tests {