mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-02 20:59:56 +00:00
Compare commits
12 Commits
backport/s
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe3c14413c | ||
|
|
8e3b284f4b | ||
|
|
21aa933584 | ||
|
|
1dfa85a917 | ||
|
|
859fe19fff | ||
|
|
e40cb294f6 | ||
|
|
e203e0f42a | ||
|
|
167be3a30f | ||
|
|
1d8b5f6e5c | ||
|
|
520370a8b0 | ||
|
|
b5a16a1898 | ||
|
|
449b5cbb80 |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -158,7 +158,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,CGO_ENABLED' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
|
||||
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
@@ -293,8 +293,11 @@ jobs:
|
||||
${{ steps.goreleaser.outputs.artifacts }}
|
||||
JSON
|
||||
|
||||
# dockers_v2 artifacts have no top-level goarch field, so match the
|
||||
# per-platform -amd64 tag suffix instead; it works for both the old
|
||||
# dockers and the new dockers_v2 image naming.
|
||||
mapfile -t src_images < <(
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
|
||||
jq -r '.[] | select(.type == "Docker Image") | .name | select(startswith("ghcr.io/") and endswith("-amd64"))' /tmp/goreleaser-artifacts.json
|
||||
)
|
||||
|
||||
for src in "${src_images[@]}"; do
|
||||
|
||||
@@ -10,7 +10,7 @@ var (
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
|
||||
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||
EnvKeyNBLazyConn = lazyconn.EnvLazyConn
|
||||
|
||||
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||
|
||||
@@ -71,12 +71,14 @@ var (
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
lazyConnEnabled bool
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
// lazyConnEnabled is the parse target for the deprecated --enable-lazy-connection
|
||||
// flag. The flag is inert; the value is no longer read (use NB_LAZY_CONN instead).
|
||||
lazyConnEnabled bool
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -210,7 +212,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "Deprecated: no longer used. Lazy connections are controlled by the server and the NB_LAZY_CONN environment variable.")
|
||||
_ = upCmd.PersistentFlags().MarkDeprecated(enableLazyConnectionFlag, "no longer used; lazy connections are controlled by the server and the NB_LAZY_CONN environment variable")
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -479,10 +479,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
req.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
return &req
|
||||
}
|
||||
|
||||
@@ -600,9 +596,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.DisableIPv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
@@ -718,9 +711,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &loginRequest, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -322,7 +322,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.DisableIPv6,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
|
||||
@@ -16,6 +16,16 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// lazyForce is the resolved local decision for lazy connections, layered above the
|
||||
// management feature flag. lazyForceNone defers to management.
|
||||
type lazyForce int
|
||||
|
||||
const (
|
||||
lazyForceNone lazyForce = iota
|
||||
lazyForceOn
|
||||
lazyForceOff
|
||||
)
|
||||
|
||||
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||
//
|
||||
// The connection manager is responsible for:
|
||||
@@ -28,7 +38,7 @@ type ConnMgr struct {
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
enabledLocally bool
|
||||
force lazyForce
|
||||
rosenpassEnabled bool
|
||||
|
||||
lazyConnMgr *manager.Manager
|
||||
@@ -43,28 +53,34 @@ func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerSto
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
force: resolveLazyForce(engineConfig.LazyConnection),
|
||||
rosenpassEnabled: engineConfig.RosenpassEnabled,
|
||||
}
|
||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||
e.enabledLocally = true
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||
// Start initializes the connection manager. It starts the lazy connection manager when a
|
||||
// local override forces it on; with no local override it waits for the management feature flag.
|
||||
func (e *ConnMgr) Start(ctx context.Context) {
|
||||
if e.lazyConnMgr != nil {
|
||||
log.Errorf("lazy connection manager is already started")
|
||||
return
|
||||
}
|
||||
|
||||
if !e.enabledLocally {
|
||||
log.Infof("lazy connection manager is disabled")
|
||||
switch e.force {
|
||||
case lazyForceOff:
|
||||
log.Infof("lazy connection manager is disabled by local override (%s or MDM policy)", lazyconn.EnvLazyConn)
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return
|
||||
case lazyForceNone:
|
||||
log.Infof("lazy connection manager is managed by the management feature flag")
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return
|
||||
}
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -76,8 +92,8 @@ func (e *ConnMgr) Start(ctx context.Context) {
|
||||
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||
// do not disable lazy connection manager if it was enabled by env var
|
||||
if e.enabledLocally {
|
||||
// a local override (NB_LAZY_CONN or local config) takes precedence over management
|
||||
if e.force != lazyForceNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -89,6 +105,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -98,6 +115,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
||||
return e.addPeersToLazyConnManager()
|
||||
} else {
|
||||
if e.lazyConnMgr == nil {
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return nil
|
||||
}
|
||||
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||
@@ -309,6 +327,25 @@ func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
|
||||
}
|
||||
|
||||
// resolveLazyForce determines the local override. NB_LAZY_CONN takes precedence; when it
|
||||
// is unset the MDM policy override (mdmState) applies. Either wins in both directions over
|
||||
// the management feature flag; StateUnset for both defers to management.
|
||||
func resolveLazyForce(mdmState lazyconn.State) lazyForce {
|
||||
state := lazyconn.EnvState()
|
||||
if state == lazyconn.StateUnset {
|
||||
state = mdmState
|
||||
}
|
||||
|
||||
switch state {
|
||||
case lazyconn.StateOn:
|
||||
return lazyForceOn
|
||||
case lazyconn.StateOff:
|
||||
return lazyForceOff
|
||||
default:
|
||||
return lazyForceNone
|
||||
}
|
||||
}
|
||||
|
||||
func inactivityThresholdEnv() *time.Duration {
|
||||
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||
if envValue == "" {
|
||||
|
||||
40
client/internal/conn_mgr_test.go
Normal file
40
client/internal/conn_mgr_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestResolveLazyForce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
env string
|
||||
envSet bool
|
||||
mdm lazyconn.State
|
||||
want lazyForce
|
||||
}{
|
||||
{name: "env unset, mdm unset -> defer to management", mdm: lazyconn.StateUnset, want: lazyForceNone},
|
||||
{name: "env on -> force on", env: "on", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOn},
|
||||
{name: "env off -> force off", env: "off", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOff},
|
||||
{name: "env unset, mdm on -> force on", mdm: lazyconn.StateOn, want: lazyForceOn},
|
||||
{name: "env unset, mdm off -> force off", mdm: lazyconn.StateOff, want: lazyForceOff},
|
||||
{name: "env on beats mdm off", env: "on", envSet: true, mdm: lazyconn.StateOff, want: lazyForceOn},
|
||||
{name: "env off beats mdm on", env: "off", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOff},
|
||||
{name: "unrecognized env, mdm on -> mdm wins", env: "auto", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOn},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv(lazyconn.EnvLazyConn, tt.env)
|
||||
if !tt.envSet {
|
||||
os.Unsetenv(lazyconn.EnvLazyConn)
|
||||
}
|
||||
|
||||
if got := resolveLazyForce(tt.mdm); got != tt.want {
|
||||
t.Fatalf("resolveLazyForce(%v) = %v, want %v", tt.mdm, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -601,7 +602,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
BlockInbound: config.BlockInbound,
|
||||
DisableIPv6: config.DisableIPv6,
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
LazyConnection: lazyconn.ParseState(config.LazyConnection),
|
||||
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
@@ -675,7 +676,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.DisableIPv6,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
|
||||
@@ -681,7 +681,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnection: %q\n", g.internalConfig.LazyConnection))
|
||||
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||
}
|
||||
|
||||
|
||||
@@ -885,7 +885,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
DNSRouteInterval: 5 * time.Second,
|
||||
ClientCertPath: "/tmp/cert",
|
||||
ClientCertKeyPath: "/tmp/key",
|
||||
LazyConnectionEnabled: true,
|
||||
LazyConnection: "on",
|
||||
MTU: 1280,
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -147,7 +148,9 @@ type EngineConfig struct {
|
||||
BlockInbound bool
|
||||
DisableIPv6 bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
// LazyConnection is the MDM-sourced lazy-connection override; StateUnset defers to
|
||||
// the env var and management feature flag.
|
||||
LazyConnection lazyconn.State
|
||||
|
||||
MTU uint16
|
||||
|
||||
@@ -1130,7 +1133,6 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
@@ -1999,7 +2001,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
|
||||
@@ -3,24 +3,57 @@ package lazyconn
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
|
||||
EnvLazyConn = "NB_LAZY_CONN"
|
||||
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
|
||||
)
|
||||
|
||||
func IsLazyConnEnabledByEnv() bool {
|
||||
val := os.Getenv(EnvEnableLazyConn)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
enabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
// State is the tri-state local override for lazy connections read from the environment.
|
||||
type State int
|
||||
|
||||
const (
|
||||
// StateUnset means no local override; defer to the management feature flag.
|
||||
StateUnset State = iota
|
||||
// StateOn forces lazy connections on, overriding management.
|
||||
StateOn
|
||||
// StateOff forces lazy connections off, overriding management.
|
||||
StateOff
|
||||
)
|
||||
|
||||
// EnvState reads NB_LAZY_CONN and returns the local override state.
|
||||
func EnvState() State {
|
||||
return ParseState(os.Getenv(EnvLazyConn))
|
||||
}
|
||||
|
||||
// ParseState interprets a lazy-connection override value (from the environment or an MDM
|
||||
// policy). It accepts the on/off aliases plus any value strconv.ParseBool understands
|
||||
// (true/false/1/0). An empty or unrecognized value returns StateUnset so that the
|
||||
// management feature flag remains in control.
|
||||
func ParseState(raw string) State {
|
||||
if raw == "" {
|
||||
return StateUnset
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch normalized {
|
||||
case "on":
|
||||
return StateOn
|
||||
case "off":
|
||||
return StateOff
|
||||
}
|
||||
|
||||
enabled, err := strconv.ParseBool(normalized)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse lazy connection value %q (from %s env or MDM policy): %v", raw, EnvLazyConn, err)
|
||||
return StateUnset
|
||||
}
|
||||
if enabled {
|
||||
return StateOn
|
||||
}
|
||||
return StateOff
|
||||
}
|
||||
|
||||
45
client/internal/lazyconn/env_test.go
Normal file
45
client/internal/lazyconn/env_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnvState(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
set bool
|
||||
want State
|
||||
}{
|
||||
{set: false, want: StateUnset},
|
||||
{value: "", set: true, want: StateUnset},
|
||||
{value: "on", set: true, want: StateOn},
|
||||
{value: "ON", set: true, want: StateOn},
|
||||
{value: "true", set: true, want: StateOn},
|
||||
{value: "1", set: true, want: StateOn},
|
||||
{value: " on ", set: true, want: StateOn},
|
||||
{value: "off", set: true, want: StateOff},
|
||||
{value: "OFF", set: true, want: StateOff},
|
||||
{value: "false", set: true, want: StateOff},
|
||||
{value: "0", set: true, want: StateOff},
|
||||
{value: "auto", set: true, want: StateUnset},
|
||||
{value: "garbage", set: true, want: StateUnset},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
name := tt.value
|
||||
if !tt.set {
|
||||
name = "unset"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Setenv(EnvLazyConn, tt.value)
|
||||
if !tt.set {
|
||||
os.Unsetenv(EnvLazyConn)
|
||||
}
|
||||
|
||||
if got := EnvState(); got != tt.want {
|
||||
t.Fatalf("EnvState() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -101,8 +101,6 @@ type ConfigInput struct {
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
LazyConnectionEnabled *bool
|
||||
|
||||
MTU *uint16
|
||||
}
|
||||
|
||||
@@ -180,7 +178,9 @@ type Config struct {
|
||||
|
||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
// LazyConnection is the MDM-managed lazy-connection override ("on"/"off"/"").
|
||||
// Runtime-only: re-derived from MDM policy on each load, never persisted.
|
||||
LazyConnection string `json:"-"`
|
||||
|
||||
MTU uint16
|
||||
|
||||
@@ -632,12 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.MTU != nil && *input.MTU != config.MTU {
|
||||
log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU)
|
||||
config.MTU = *input.MTU
|
||||
@@ -728,6 +722,15 @@ func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetBool(mdm.KeyLazyConnection); ok {
|
||||
state := "off"
|
||||
if v {
|
||||
state = "on"
|
||||
}
|
||||
config.LazyConnection = state
|
||||
logApplied(mdm.KeyLazyConnection, state)
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
|
||||
@@ -130,6 +130,37 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMLazyConnection(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want string
|
||||
}{
|
||||
{"native true", true, "on"},
|
||||
{"native false", false, "off"},
|
||||
{"string on", "on", "on"},
|
||||
{"string off", "off", "off"},
|
||||
{"string yes", "yes", "on"},
|
||||
{"string no", "no", "off"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyLazyConnection: c.raw,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, c.want, cfg.LazyConnection)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyLazyConnection))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func newExitNodeTestManager() *DefaultManager {
|
||||
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
}
|
||||
|
||||
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
|
||||
return &route.Route{
|
||||
NetID: route.NetID(netID),
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
Peer: peer,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickPreferredExitNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info exitNodeInfo
|
||||
want route.NetID
|
||||
}{
|
||||
{
|
||||
name: "persisted user selection wins over management",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"b"},
|
||||
selectedByManagement: []route.NetID{"a"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "multiple user-selected self-heal to deterministic min",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"c", "a"},
|
||||
},
|
||||
want: "a",
|
||||
},
|
||||
{
|
||||
name: "explicit opt-out keeps none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "fresh defaults to management auto-apply pick",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "no user pick and no management auto-apply selects none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"c", "a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "user-deselect does not block a management auto-apply sibling",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
all := []route.NetID{"a", "b", "c"}
|
||||
|
||||
m.enforceSingleExitNode("b", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
|
||||
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
|
||||
|
||||
// Switching the preferred node moves the single selection.
|
||||
m.enforceSingleExitNode("c", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
|
||||
|
||||
// Empty preferred turns every exit node off.
|
||||
m.enforceSingleExitNode("", all)
|
||||
for _, id := range all {
|
||||
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
|
||||
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
|
||||
|
||||
assert.True(t, m.routeSelector.IsDeselectAll(), "global deselect-all must stay in effect")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
|
||||
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
// Exactly one exit node (the deterministic first) is selected.
|
||||
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
|
||||
// Non-exit routes are left at their default-on state.
|
||||
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// Simulate the state the runtime select path leaves behind: exactly one
|
||||
// exit node explicitly selected, its sibling deselected.
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// User deselected exit nodes and selected none.
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
// SkipAutoApply=true: management offers the exit nodes but doesn't request
|
||||
// auto-activation, so none should be selected until the user picks one.
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
|
||||
}
|
||||
@@ -701,13 +701,7 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
return ips
|
||||
}
|
||||
|
||||
// updateRouteSelectorFromManagement reconciles exit-node selection on every
|
||||
// network map: it keeps at most one exit node selected — the user's persisted
|
||||
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
|
||||
// else none. We never auto-activate an exit node the map doesn't request; it
|
||||
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
|
||||
// RouteSelector stores routes with default-on semantics, so without this every
|
||||
// available exit node would report selected at once.
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
m.mirrorV6ExitPairSelections(clientRoutes)
|
||||
|
||||
@@ -718,14 +712,13 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
info := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(info.allIDs) == 0 {
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
preferred := pickPreferredExitNode(info)
|
||||
m.enforceSingleExitNode(preferred, info.allIDs)
|
||||
m.logExitNodeUpdate(info, preferred)
|
||||
m.updateExitNodeSelections(exitNodeInfo)
|
||||
m.logExitNodeUpdate(exitNodeInfo)
|
||||
}
|
||||
|
||||
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
||||
@@ -753,10 +746,6 @@ type exitNodeInfo struct {
|
||||
userDeselected []route.NetID
|
||||
}
|
||||
|
||||
// collectExitNodeInfo categorises the available exit nodes by their persisted
|
||||
// selection state. It keys on the base (v4) NetID and skips the synthesized
|
||||
// "-v6" partner, which inherits its base's selection through the RouteSelector
|
||||
// — counting it separately would double-count the pair.
|
||||
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
||||
var info exitNodeInfo
|
||||
|
||||
@@ -766,9 +755,6 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
||||
}
|
||||
|
||||
netID := haID.NetID()
|
||||
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
|
||||
continue
|
||||
}
|
||||
info.allIDs = append(info.allIDs, netID)
|
||||
|
||||
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
@@ -805,52 +791,45 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
|
||||
}
|
||||
}
|
||||
|
||||
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
|
||||
// - a persisted user selection wins (deterministic if several survive from
|
||||
// legacy state, so the set self-heals down to one);
|
||||
// - otherwise activate only what management marks for auto-apply
|
||||
// (SkipAutoApply=false); the lexicographically first if it marks several.
|
||||
//
|
||||
// Returns "" when neither holds — we never force an arbitrary exit node on. A
|
||||
// route the map doesn't auto-apply stays off until the user selects it.
|
||||
// info.userDeselected is informational only: an explicit deselect simply keeps
|
||||
// that route out of both lists above, so it can't be picked.
|
||||
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
|
||||
if len(info.userSelected) > 0 {
|
||||
return minNetID(info.userSelected)
|
||||
}
|
||||
if len(info.selectedByManagement) > 0 {
|
||||
return minNetID(info.selectedByManagement)
|
||||
}
|
||||
return ""
|
||||
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
|
||||
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
|
||||
m.deselectExitNodes(routesToDeselect)
|
||||
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
|
||||
}
|
||||
|
||||
// enforceSingleExitNode makes preferred the only selected exit node: every other
|
||||
// available exit node is deselected and preferred (if any) is selected, without
|
||||
// disturbing non-exit route selections. The whole reconciliation runs under a
|
||||
// single RouteSelector lock (SetExclusiveExitNode) so a concurrent deselect-all
|
||||
// cannot interleave and get undone; a global deselect-all is left untouched so
|
||||
// the user's "all off" stays in effect.
|
||||
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
|
||||
m.routeSelector.SetExclusiveExitNode(preferred, allIDs)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
|
||||
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
|
||||
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
|
||||
}
|
||||
|
||||
// minNetID returns the lexicographically smallest NetID, for a deterministic
|
||||
// default pick that stays stable across restarts.
|
||||
func minNetID(ids []route.NetID) route.NetID {
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
best := ids[0]
|
||||
for _, id := range ids[1:] {
|
||||
if id < best {
|
||||
best = id
|
||||
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
|
||||
var routesToDeselect []route.NetID
|
||||
for _, netID := range allIDs {
|
||||
if !m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
routesToDeselect = append(routesToDeselect, netID)
|
||||
}
|
||||
}
|
||||
return best
|
||||
return routesToDeselect
|
||||
}
|
||||
|
||||
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
|
||||
if len(routesToDeselect) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to deselect exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
|
||||
if len(selectedByManagement) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to select exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
|
||||
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
|
||||
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
|
||||
}
|
||||
|
||||
@@ -115,38 +115,7 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// SetExclusiveExitNode atomically makes preferred the only selected exit node
|
||||
// among exitIDs: every other ID in exitIDs is deselected and preferred (when
|
||||
// non-empty) is selected, all under a single lock. Holding the lock across the
|
||||
// whole reconciliation prevents a concurrent DeselectAllRoutes from interleaving
|
||||
// between the deselect and select steps and being silently undone. A global
|
||||
// deselect-all is left untouched so the user's "all off" stays in effect;
|
||||
// non-exit routes are never referenced, so their selection is preserved.
|
||||
func (rs *RouteSelector) SetExclusiveExitNode(preferred route.NetID, exitIDs []route.NetID) {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
return
|
||||
}
|
||||
|
||||
for _, id := range exitIDs {
|
||||
if id == preferred {
|
||||
continue
|
||||
}
|
||||
rs.deselectedRoutes[id] = struct{}{}
|
||||
delete(rs.selectedRoutes, id)
|
||||
}
|
||||
|
||||
if preferred != "" {
|
||||
delete(rs.deselectedRoutes, preferred)
|
||||
rs.selectedRoutes[preferred] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the global "deselect all" flag is set, i.e. the
|
||||
// user explicitly disabled every route. Callers enforcing per-route invariants
|
||||
// (e.g. single exit node) should leave the selection untouched when it is.
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
@@ -38,7 +38,7 @@ func GetEnvKeyNBForceRelay() string {
|
||||
|
||||
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBLazyConn() string {
|
||||
return lazyconn.EnvEnableLazyConn
|
||||
return lazyconn.EnvLazyConn
|
||||
}
|
||||
|
||||
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
|
||||
|
||||
@@ -27,6 +27,7 @@ var allKeys = []string{
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
KeyLazyConnection,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
|
||||
@@ -11,6 +11,7 @@ package mdm
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -41,6 +42,11 @@ const (
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
|
||||
// KeyLazyConnection forces the lazy-connection feature on or off, overriding
|
||||
// the management feature flag. Read as a bool (native bool, or on/off,
|
||||
// true/false, 1/0, yes/no); absent = defer to management.
|
||||
KeyLazyConnection = "lazyConnection"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
@@ -62,12 +68,13 @@ var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"on": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
"off": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
@@ -150,7 +157,8 @@ func (p *Policy) GetString(key string) (string, bool) {
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
// key was set. Accepts native bool and string literals (true/false, 1/0,
|
||||
// yes/no, on/off), case-insensitively and trimmed of surrounding whitespace.
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
@@ -163,7 +171,7 @@ func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
b, known := boolStringLiterals[strings.ToLower(strings.TrimSpace(t))]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
|
||||
@@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) {
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
@@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
@@ -85,6 +85,11 @@ func TestPolicy_GetBool(t *testing.T) {
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"string on", "on", true, true},
|
||||
{"string off", "off", false, true},
|
||||
{"mixed case On", "On", true, true},
|
||||
{"upper TRUE", "TRUE", true, true},
|
||||
{"padded yes", " yes ", true, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
|
||||
@@ -152,7 +152,6 @@ func (s *Server) restartEngineForMDMLocked() error {
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
@@ -305,7 +304,6 @@ func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
@@ -348,7 +346,6 @@ func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
|
||||
@@ -214,7 +214,6 @@ func (s *Server) Start() error {
|
||||
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
if s.sessionWatcher == nil {
|
||||
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
|
||||
@@ -463,7 +462,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
|
||||
config.DisableFirewall = msg.DisableFirewall
|
||||
config.BlockLANAccess = msg.BlockLanAccess
|
||||
config.DisableNotifications = msg.DisableNotifications
|
||||
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
config.BlockInbound = msg.BlockInbound
|
||||
config.DisableIPv6 = msg.DisableIpv6
|
||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||
@@ -1647,7 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
|
||||
@@ -69,43 +69,41 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
disableFirewall := true
|
||||
blockLANAccess := true
|
||||
disableNotifications := true
|
||||
lazyConnectionEnabled := true
|
||||
blockInbound := true
|
||||
disableIPv6 := true
|
||||
mtu := int64(1280)
|
||||
sshJWTCacheTTL := int32(300)
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: currUser.Username,
|
||||
ManagementUrl: "https://new-api.netbird.io:443",
|
||||
AdminURL: "https://new-admin.netbird.io",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
DisableAutoConnect: &disableAutoConnect,
|
||||
NetworkMonitor: &networkMonitor,
|
||||
DisableClientRoutes: &disableClientRoutes,
|
||||
DisableServerRoutes: &disableServerRoutes,
|
||||
DisableDns: &disableDNS,
|
||||
DisableFirewall: &disableFirewall,
|
||||
BlockLanAccess: &blockLANAccess,
|
||||
DisableNotifications: &disableNotifications,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableIpv6: &disableIPv6,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||
DnsLabels: []string{"label1", "label2"},
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
||||
ProfileName: profName,
|
||||
Username: currUser.Username,
|
||||
ManagementUrl: "https://new-api.netbird.io:443",
|
||||
AdminURL: "https://new-admin.netbird.io",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
DisableAutoConnect: &disableAutoConnect,
|
||||
NetworkMonitor: &networkMonitor,
|
||||
DisableClientRoutes: &disableClientRoutes,
|
||||
DisableServerRoutes: &disableServerRoutes,
|
||||
DisableDns: &disableDNS,
|
||||
DisableFirewall: &disableFirewall,
|
||||
BlockLanAccess: &blockLANAccess,
|
||||
DisableNotifications: &disableNotifications,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableIpv6: &disableIPv6,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||
DnsLabels: []string{"label1", "label2"},
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
||||
}
|
||||
|
||||
_, err = s.SetConfig(ctx, req)
|
||||
@@ -140,7 +138,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
|
||||
require.NotNil(t, cfg.DisableNotifications)
|
||||
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
|
||||
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
|
||||
require.Equal(t, blockInbound, cfg.BlockInbound)
|
||||
require.Equal(t, disableIPv6, cfg.DisableIPv6)
|
||||
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
|
||||
@@ -164,13 +161,14 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
t.Helper()
|
||||
|
||||
metadataFields := map[string]bool{
|
||||
"state": true, // protobuf internal
|
||||
"sizeCache": true, // protobuf internal
|
||||
"unknownFields": true, // protobuf internal
|
||||
"Username": true, // metadata
|
||||
"ProfileName": true, // metadata
|
||||
"CleanNATExternalIPs": true, // control flag for clearing
|
||||
"CleanDNSLabels": true, // control flag for clearing
|
||||
"state": true, // protobuf internal
|
||||
"sizeCache": true, // protobuf internal
|
||||
"unknownFields": true, // protobuf internal
|
||||
"Username": true, // metadata
|
||||
"ProfileName": true, // metadata
|
||||
"CleanNATExternalIPs": true, // control flag for clearing
|
||||
"CleanDNSLabels": true, // control flag for clearing
|
||||
"LazyConnectionEnabled": true, // deprecated: proto field retained for compat, no longer applied
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
@@ -190,7 +188,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"DisableIpv6": true,
|
||||
"NatExternalIPs": true,
|
||||
@@ -252,7 +249,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"disable-ipv6": "DisableIpv6",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
@@ -269,7 +265,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
fieldsWithoutCLIFlags := map[string]bool{
|
||||
"DisableNotifications": true, // Only settable via UI
|
||||
"DisableNotifications": true, // Only settable via UI
|
||||
"LazyConnectionEnabled": true, // deprecated: no longer settable (managed by server + NB_LAZY_CONN)
|
||||
}
|
||||
|
||||
// Get all SetConfigRequest fields to verify our map is complete.
|
||||
|
||||
@@ -74,8 +74,6 @@ type Info struct {
|
||||
BlockInbound bool
|
||||
DisableIPv6 bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
EnableSSHRoot bool
|
||||
EnableSSHSFTP bool
|
||||
EnableSSHLocalPortForwarding bool
|
||||
@@ -87,7 +85,7 @@ func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6 bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
disableSSHAuth *bool,
|
||||
) {
|
||||
@@ -105,8 +103,6 @@ func (i *Info) SetFlags(
|
||||
i.BlockInbound = blockInbound
|
||||
i.DisableIPv6 = disableIPv6
|
||||
|
||||
i.LazyConnectionEnabled = lazyConnectionEnabled
|
||||
|
||||
if enableSSHRoot != nil {
|
||||
i.EnableSSHRoot = *enableSSHRoot
|
||||
}
|
||||
|
||||
@@ -266,7 +266,6 @@ type serviceClient struct {
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mLazyConnEnabled *systray.MenuItem
|
||||
mBlockInbound *systray.MenuItem
|
||||
mNotifications *systray.MenuItem
|
||||
mAdvancedSettings *systray.MenuItem
|
||||
@@ -336,11 +335,11 @@ type serviceClient struct {
|
||||
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
|
||||
// AND s.connected — both must be true for the menus to be active.
|
||||
// Zero value (false) matches the Disable() call at AddMenuItem time.
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
|
||||
eventManager *event.Manager
|
||||
|
||||
@@ -1094,7 +1093,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
|
||||
s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false)
|
||||
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
|
||||
s.mSettings.AddSeparator()
|
||||
@@ -1578,7 +1576,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
config.RosenpassEnabled = cfg.RosenpassEnabled
|
||||
config.RosenpassPermissive = cfg.RosenpassPermissive
|
||||
config.DisableNotifications = &cfg.DisableNotifications
|
||||
config.LazyConnectionEnabled = cfg.LazyConnectionEnabled
|
||||
config.BlockInbound = cfg.BlockInbound
|
||||
config.NetworkMonitor = &cfg.NetworkMonitor
|
||||
config.DisableDNS = cfg.DisableDns
|
||||
@@ -1682,12 +1679,6 @@ func (s *serviceClient) loadSettings() {
|
||||
s.mEnableRosenpass.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.LazyConnectionEnabled {
|
||||
s.mLazyConnEnabled.Check()
|
||||
} else {
|
||||
s.mLazyConnEnabled.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.BlockInbound {
|
||||
s.mBlockInbound.Check()
|
||||
} else {
|
||||
@@ -1833,7 +1824,6 @@ func (s *serviceClient) updateConfig() error {
|
||||
disableAutoStart := !s.mAutoConnect.Checked()
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||
blockInbound := s.mBlockInbound.Checked()
|
||||
notificationsDisabled := !s.mNotifications.Checked()
|
||||
|
||||
@@ -1856,14 +1846,13 @@ func (s *serviceClient) updateConfig() error {
|
||||
}
|
||||
|
||||
req := proto.SetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableNotifications: ¬ificationsDisabled,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableNotifications: ¬ificationsDisabled,
|
||||
}
|
||||
|
||||
if _, err := conn.SetConfig(s.ctx, &req); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ const (
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
|
||||
blockInboundMenuDescr = "Block inbound connections to the local machine and routed networks"
|
||||
notificationsMenuDescr = "Enable notifications"
|
||||
advancedSettingsMenuDescr = "Advanced settings of the application"
|
||||
|
||||
@@ -43,8 +43,6 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleAutoConnectClick()
|
||||
case <-h.client.mEnableRosenpass.ClickedCh:
|
||||
h.handleRosenpassClick()
|
||||
case <-h.client.mLazyConnEnabled.ClickedCh:
|
||||
h.handleLazyConnectionClick()
|
||||
case <-h.client.mBlockInbound.ClickedCh:
|
||||
h.handleBlockInboundClick()
|
||||
case <-h.client.mAdvancedSettings.ClickedCh:
|
||||
@@ -152,15 +150,6 @@ func (h *eventHandler) handleRosenpassClick() {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleLazyConnectionClick() {
|
||||
h.toggleCheckbox(h.client.mLazyConnEnabled)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleBlockInboundClick() {
|
||||
h.toggleCheckbox(h.client.mBlockInbound)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
|
||||
91
combined/cmd/admin.go
Normal file
91
combined/cmd/admin.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newAdminCommands creates the admin command tree with combined-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(withAdminResources)
|
||||
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
|
||||
mgmtConfig, err := cfg.ToManagementConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create management config: %w", err)
|
||||
}
|
||||
|
||||
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := idpStorage.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminTokenStore opens only the management store for admin token commands.
|
||||
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := managementStore.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, managementStore, cfg)
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
rootCmd.AddCommand(newAdminCommands())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
return tokencmd.NewCommands(withTokenStore)
|
||||
}
|
||||
|
||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
171
e2e/agentnetwork/vllm_test.go
Normal file
171
e2e/agentnetwork/vllm_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
//go:build e2e
|
||||
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/e2e/harness"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestVLLMProvider proves the proxy supports a self-hosted vLLM backend. vLLM is
|
||||
// OpenAI-compatible, so it uses the "vllm" catalog entry (KindCustom) and is
|
||||
// reached over plain HTTP — no TLS anywhere on the path:
|
||||
//
|
||||
// client --tunnel--> netbird proxy --http--> vllm (:8000, OpenAI-compatible)
|
||||
//
|
||||
// The mock vLLM server answers /v1/chat/completions with an OpenAI-shaped
|
||||
// completion carrying a non-zero usage block. The test asserts the chat returns
|
||||
// 200 with the completion, that the request is recorded in the access log by its
|
||||
// session id, and that vLLM's usage block is metered into a consumption row —
|
||||
// which together prove request routing, response parsing, and token accounting
|
||||
// all work for a self-hosted OpenAI-compatible provider.
|
||||
//
|
||||
// It needs no external credentials (the mock ignores auth), so it always runs.
|
||||
func TestVLLMProvider(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
vllm, err := harness.StartVLLM(ctx, srv)
|
||||
require.NoError(t, err, "start mock vLLM server")
|
||||
t.Cleanup(func() { _ = vllm.Terminate(context.Background()) })
|
||||
|
||||
grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-vllm"})
|
||||
require.NoError(t, err, "create group")
|
||||
t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) })
|
||||
|
||||
ephemeral := false
|
||||
sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{
|
||||
Name: "e2e-vllm-client",
|
||||
Type: "reusable",
|
||||
ExpiresIn: 86400,
|
||||
UsageLimit: 0,
|
||||
AutoGroups: []string{grp.Id},
|
||||
Ephemeral: &ephemeral,
|
||||
})
|
||||
require.NoError(t, err, "mint setup key")
|
||||
require.NotEmpty(t, sk.Key, "setup key plaintext")
|
||||
|
||||
// vLLM provider pointed at the mock over plain HTTP. The mock ignores auth,
|
||||
// so a dummy key satisfies the "Bearer ${API_KEY}" template. The served model
|
||||
// is enumerated so the router dispatches this model string to this provider.
|
||||
dummyKey := "sk-vllm-e2e"
|
||||
prov, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{
|
||||
Name: "vllm",
|
||||
ProviderId: "vllm",
|
||||
UpstreamUrl: vllm.URL,
|
||||
ApiKey: &dummyKey,
|
||||
Enabled: ptr(true),
|
||||
BootstrapCluster: ptr(harness.AgentNetworkCluster),
|
||||
Models: &[]api.AgentNetworkProviderModel{
|
||||
{Id: harness.VLLMModel, InputPer1k: 0.001, OutputPer1k: 0.002},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "create vllm provider")
|
||||
t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) })
|
||||
|
||||
// Token limit far above the handful of tokens this test drives, so it never
|
||||
// blocks but switches on usage metering — the switch that makes consumption
|
||||
// rows get recorded.
|
||||
enabled := true
|
||||
pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{
|
||||
Name: "e2e-vllm-allow",
|
||||
Enabled: &enabled,
|
||||
SourceGroups: []string{grp.Id},
|
||||
DestinationProviderIds: []string{prov.Id},
|
||||
Limits: &api.AgentNetworkPolicyLimits{
|
||||
TokenLimit: api.AgentNetworkPolicyTokenLimit{
|
||||
Enabled: true,
|
||||
GroupCap: 10_000_000,
|
||||
UserCap: 10_000_000,
|
||||
WindowSeconds: 60,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "create policy")
|
||||
t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) })
|
||||
|
||||
settings, err := srv.GetSettings(ctx)
|
||||
require.NoError(t, err, "read settings")
|
||||
require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned")
|
||||
|
||||
proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-vllm-proxy")
|
||||
require.NoError(t, err, "mint proxy token")
|
||||
px, err := harness.StartProxy(ctx, srv, proxyToken)
|
||||
require.NoError(t, err, "start proxy")
|
||||
t.Cleanup(func() { _ = px.Terminate(context.Background()) })
|
||||
|
||||
cl, err := harness.StartClient(ctx, srv, sk.Key)
|
||||
require.NoError(t, err, "start client")
|
||||
t.Cleanup(func() { _ = cl.Terminate(context.Background()) })
|
||||
|
||||
require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management")
|
||||
if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil {
|
||||
t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background()))
|
||||
}
|
||||
proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint)
|
||||
require.NoError(t, err, "resolve endpoint to proxy IP")
|
||||
|
||||
before, _ := srv.ListAccessLogs(ctx)
|
||||
sessionID := "e2e-session-vllm"
|
||||
|
||||
// Retry to absorb tunnel/DNS jitter on the first call.
|
||||
var code int
|
||||
var body string
|
||||
deadline := time.Now().Add(90 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, harness.VLLMModel, "Reply with exactly: pong", sessionID)
|
||||
if cerr == nil {
|
||||
code, body = c, b
|
||||
if code == 200 {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
require.Equal(t, 200, code,
|
||||
"chat through the vLLM provider must return 200; body: %s\n=== vllm logs ===\n%s\n=== proxy logs ===\n%s",
|
||||
body, vllm.Logs(context.Background()), px.Logs(context.Background()))
|
||||
require.True(t, strings.Contains(body, "chat.completion"),
|
||||
"body should be an OpenAI-compatible chat completion; got: %s", body)
|
||||
|
||||
// The request must surface as an access-log row carrying our session id.
|
||||
require.Eventually(t, func() bool {
|
||||
logs, lerr := srv.ListAccessLogs(ctx)
|
||||
return lerr == nil && logs.TotalRecords > before.TotalRecords
|
||||
}, 30*time.Second, 2*time.Second, "an access-log row should be ingested for the vLLM provider")
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
logs, lerr := srv.ListAccessLogs(ctx)
|
||||
if lerr != nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range logs.Data {
|
||||
if r.SessionId != nil && *r.SessionId == sessionID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, 30*time.Second, 2*time.Second, "session id %q must be recorded in an access-log row", sessionID)
|
||||
|
||||
// vLLM's usage block (prompt_tokens=11, completion_tokens=2) must be parsed
|
||||
// and metered into a consumption row with positive token counts.
|
||||
require.Eventually(t, func() bool {
|
||||
rows, lerr := srv.ListConsumption(ctx)
|
||||
if lerr != nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range rows {
|
||||
if r.TokensInput > 0 && r.TokensOutput > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, 60*time.Second, 3*time.Second, "vLLM usage must be metered into a consumption row")
|
||||
}
|
||||
113
e2e/harness/vllm.go
Normal file
113
e2e/harness/vllm.go
Normal file
@@ -0,0 +1,113 @@
|
||||
//go:build e2e
|
||||
|
||||
package harness
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
const (
|
||||
vllmImage = "nginx:alpine"
|
||||
vllmAlias = "vllm"
|
||||
vllmPort = "8000/tcp"
|
||||
// VLLMModel is the served model id the mock advertises and echoes back. It
|
||||
// matches a real small model commonly served by vLLM so the provider's
|
||||
// enumerated model and the client's request line up.
|
||||
VLLMModel = "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
)
|
||||
|
||||
// vllmNginxConf emulates a vLLM OpenAI-compatible server over plain HTTP (vLLM's
|
||||
// default: no TLS, port 8000). It answers /v1/models with a one-model list and
|
||||
// any chat/completions path with a canned OpenAI-shaped chat completion carrying
|
||||
// a non-zero usage block, so the proxy's OpenAI parser records real token
|
||||
// consumption. Running actual vLLM in CI is infeasible (GPU + multi-GB model
|
||||
// download), so this stands in for the wire contract the proxy depends on.
|
||||
const vllmNginxConf = `pid /tmp/nginx.pid;
|
||||
events {}
|
||||
http {
|
||||
server {
|
||||
listen 8000;
|
||||
location = /v1/models {
|
||||
default_type application/json;
|
||||
return 200 '{"object":"list","data":[{"id":"Qwen/Qwen2.5-0.5B-Instruct","object":"model","owned_by":"vllm"}]}';
|
||||
}
|
||||
location / {
|
||||
default_type application/json;
|
||||
return 200 '{"id":"chatcmpl-e2e-vllm","object":"chat.completion","created":1700000000,"model":"Qwen/Qwen2.5-0.5B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":2,"total_tokens":13}}';
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
// VLLM is a mock vLLM OpenAI-compatible server on the combined server's network,
|
||||
// reachable at http://vllm:8000. A "vllm" provider points at it to exercise the
|
||||
// proxy's support for self-hosted OpenAI-compatible backends.
|
||||
type VLLM struct {
|
||||
container testcontainers.Container
|
||||
workDir string
|
||||
// URL is the upstream URL the vllm provider points at (http://<alias>:8000).
|
||||
URL string
|
||||
}
|
||||
|
||||
// StartVLLM runs the mock vLLM server on the shared network over plain HTTP.
|
||||
func StartVLLM(ctx context.Context, c *Combined) (*VLLM, error) {
|
||||
workDir, err := os.MkdirTemp("/tmp", "nb-e2e-vllm-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create vllm work dir: %w", err)
|
||||
}
|
||||
// Widen so the (non-root worker) nginx container can traverse the bind mount.
|
||||
if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e config dir
|
||||
return nil, fmt.Errorf("chmod vllm dir: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(vllmNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config
|
||||
return nil, fmt.Errorf("write nginx conf: %w", err)
|
||||
}
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: vllmImage,
|
||||
ExposedPorts: []string{vllmPort},
|
||||
Networks: []string{c.network.Name},
|
||||
NetworkAliases: map[string][]string{c.network.Name: {vllmAlias}},
|
||||
Cmd: []string{"nginx", "-c", "/conf/nginx.conf", "-g", "daemon off;"},
|
||||
HostConfigModifier: func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, workDir+":/conf:ro")
|
||||
},
|
||||
WaitingFor: wait.ForListeningPort(vllmPort).WithStartupTimeout(60 * time.Second),
|
||||
}
|
||||
|
||||
ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(workDir)
|
||||
return nil, fmt.Errorf("start vllm container: %w", err)
|
||||
}
|
||||
|
||||
return &VLLM{container: ctr, workDir: workDir, URL: "http://" + vllmAlias + ":8000"}, nil
|
||||
}
|
||||
|
||||
// Logs returns the vLLM container logs, for diagnostics on failure.
|
||||
func (v *VLLM) Logs(ctx context.Context) string {
|
||||
return containerLogs(ctx, v.container)
|
||||
}
|
||||
|
||||
// Terminate stops the vLLM container and cleans its work dir.
|
||||
func (v *VLLM) Terminate(ctx context.Context) error {
|
||||
var err error
|
||||
if v.container != nil {
|
||||
err = v.container.Terminate(ctx)
|
||||
}
|
||||
if v.workDir != "" {
|
||||
_ = os.RemoveAll(v.workDir)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -351,11 +351,6 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
# Record whether the operator explicitly pinned the server/proxy images via
|
||||
# env vars, so the agent-network preset can pick its own defaults without
|
||||
# clobbering an explicit override.
|
||||
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
|
||||
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
@@ -415,15 +410,6 @@ apply_agent_network_preset() {
|
||||
ENABLE_PROXY="true"
|
||||
ENABLE_CROWDSEC="false"
|
||||
|
||||
# Agent-network ships dedicated server/proxy images. Honor an explicit
|
||||
# env override; otherwise pin the agent-network builds.
|
||||
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
|
||||
fi
|
||||
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
|
||||
fi
|
||||
|
||||
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||
else
|
||||
|
||||
89
management/cmd/admin.go
Normal file
89
management/cmd/admin.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var adminDatadir string
|
||||
|
||||
// newAdminCommands creates the admin command tree with management-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(withAdminResources)
|
||||
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
|
||||
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources initializes logging, loads config, opens the management store
|
||||
// and embedded IdP storage, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
|
||||
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := idpStorage.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminTokenStore opens only the management store for admin token commands.
|
||||
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if adminDatadir != "" {
|
||||
oldDatadir := datadir
|
||||
datadir = adminDatadir
|
||||
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
|
||||
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
|
||||
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
|
||||
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := managementStore.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, managementStore, config)
|
||||
}
|
||||
441
management/cmd/admin/admin.go
Normal file
441
management/cmd/admin/admin.go
Normal file
@@ -0,0 +1,441 @@
|
||||
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own opener to handle config loading and storage initialization.
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
localConnectorID = "local"
|
||||
dashboardClientID = "netbird-dashboard"
|
||||
cliClientID = "netbird-cli"
|
||||
defaultTOTPAuthenticatorID = "default-totp"
|
||||
)
|
||||
|
||||
// Resources contains the storages required by the admin commands.
|
||||
type Resources struct {
|
||||
Store store.Store
|
||||
IDPStorage storage.Storage
|
||||
}
|
||||
|
||||
// Opener initializes command resources from the command context and calls fn.
|
||||
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
|
||||
|
||||
type userSelector struct {
|
||||
email string
|
||||
userID string
|
||||
}
|
||||
|
||||
func (s userSelector) normalized() userSelector {
|
||||
return userSelector{
|
||||
email: strings.TrimSpace(s.email),
|
||||
userID: strings.TrimSpace(s.userID),
|
||||
}
|
||||
}
|
||||
|
||||
func (s userSelector) validate() error {
|
||||
s = s.normalized()
|
||||
if (s.email == "") == (s.userID == "") {
|
||||
return fmt.Errorf("provide exactly one of --email or --user-id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewCommands creates the admin command tree with the given resource opener.
|
||||
func NewCommands(opener Opener) *cobra.Command {
|
||||
adminCmd := &cobra.Command{
|
||||
Use: "admin",
|
||||
Short: "Self-hosted administrator helpers",
|
||||
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
|
||||
}
|
||||
|
||||
userCmd := &cobra.Command{
|
||||
Use: "user",
|
||||
Short: "Manage local embedded IdP users",
|
||||
}
|
||||
|
||||
var passwordSelector userSelector
|
||||
var password string
|
||||
var passwordFile string
|
||||
passwordCmd := &cobra.Command{
|
||||
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
|
||||
Aliases: []string{"set-password"},
|
||||
Short: "Change a local user's password",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(passwordCmd, &passwordSelector)
|
||||
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
|
||||
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
|
||||
|
||||
var resetSelector userSelector
|
||||
resetMFACmd := &cobra.Command{
|
||||
Use: "reset-mfa (--email email | --user-id id)",
|
||||
Short: "Reset a local user's MFA enrollment",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(resetMFACmd, &resetSelector)
|
||||
|
||||
userCmd.AddCommand(passwordCmd, resetMFACmd)
|
||||
|
||||
mfaCmd := &cobra.Command{
|
||||
Use: "mfa",
|
||||
Short: "Manage local MFA for embedded IdP users",
|
||||
}
|
||||
|
||||
enableCmd := &cobra.Command{
|
||||
Use: "enable",
|
||||
Short: "Enable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
disableCmd := &cobra.Command{
|
||||
Use: "disable",
|
||||
Short: "Disable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show local MFA status",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
|
||||
adminCmd.AddCommand(userCmd, mfaCmd)
|
||||
return adminCmd
|
||||
}
|
||||
|
||||
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
|
||||
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
|
||||
}
|
||||
|
||||
yamlConfig, err := cfg.ToYAMLConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build embedded IdP config: %w", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
st, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
|
||||
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
|
||||
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
|
||||
}
|
||||
|
||||
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
|
||||
if password != "" && passwordFile != "" {
|
||||
return "", fmt.Errorf("provide only one of --password or --password-file")
|
||||
}
|
||||
if passwordFile == "" {
|
||||
return password, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
if passwordFile == "-" {
|
||||
data, err = io.ReadAll(cmd.InOrStdin())
|
||||
} else {
|
||||
data, err = os.ReadFile(passwordFile)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read password: %w", err)
|
||||
}
|
||||
return strings.TrimRight(string(data), "\r\n"), nil
|
||||
}
|
||||
|
||||
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if password == "" {
|
||||
return fmt.Errorf("password is required")
|
||||
}
|
||||
if err := server.ValidatePassword(password); err != nil {
|
||||
return fmt.Errorf("invalid password: %w", err)
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
||||
old.Hash = hash
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update password for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reset := false
|
||||
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
|
||||
old.MFASecrets = map[string]*storage.MFASecret{}
|
||||
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
|
||||
return old, nil
|
||||
})
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reset {
|
||||
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accounts := resources.Store.GetAllAccounts(ctx)
|
||||
if len(accounts) != 1 {
|
||||
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
|
||||
}
|
||||
|
||||
settings := &types.Settings{}
|
||||
if accounts[0].Settings != nil {
|
||||
settings = accounts[0].Settings.Copy()
|
||||
}
|
||||
settings.LocalMfaEnabled = enabled
|
||||
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
|
||||
return fmt.Errorf("save local MFA account setting: %w", err)
|
||||
}
|
||||
|
||||
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := "disabled"
|
||||
if enabled {
|
||||
state = "enabled"
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accounts := resources.Store.GetAllAccounts(ctx)
|
||||
accountStatus := "unknown"
|
||||
if len(accounts) == 1 && accounts[0].Settings != nil {
|
||||
accountStatus = "disabled"
|
||||
if accounts[0].Settings.LocalMfaEnabled {
|
||||
accountStatus = "enabled"
|
||||
}
|
||||
}
|
||||
|
||||
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
|
||||
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return storage.Password{}, err
|
||||
}
|
||||
|
||||
if selector.email != "" {
|
||||
user, err := idpStorage.GetPassword(ctx, selector.email)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
|
||||
}
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
rawUserID := selector.userID
|
||||
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
|
||||
rawUserID = decodedUserID
|
||||
}
|
||||
|
||||
users, err := idpStorage.ListPasswords(ctx)
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("list local users: %w", err)
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.UserID == rawUserID || user.UserID == selector.userID {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
|
||||
}
|
||||
|
||||
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
|
||||
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
|
||||
if err == nil || errors.Is(err, storage.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
|
||||
var mfaChain []string
|
||||
if enabled {
|
||||
mfaChain = []string{defaultTOTPAuthenticatorID}
|
||||
}
|
||||
|
||||
for _, clientID := range []string{cliClientID, dashboardClientID} {
|
||||
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = mfaChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
|
||||
}
|
||||
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
|
||||
clientIDs := []string{cliClientID, dashboardClientID}
|
||||
enabledCount := 0
|
||||
for _, clientID := range clientIDs {
|
||||
client, err := idpStorage.GetClient(ctx, clientID)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
|
||||
}
|
||||
if err != nil {
|
||||
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
|
||||
enabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
switch enabledCount {
|
||||
case 0:
|
||||
return "disabled", nil
|
||||
case len(clientIDs):
|
||||
return "enabled", nil
|
||||
default:
|
||||
return "partially enabled", nil
|
||||
}
|
||||
}
|
||||
|
||||
func hasAuthenticator(chain []string, authenticatorID string) bool {
|
||||
for _, id := range chain {
|
||||
if id == authenticatorID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
160
management/cmd/admin/admin_test.go
Normal file
160
management/cmd/admin/admin_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
)
|
||||
|
||||
func newTestIDPStorage(t *testing.T) storage.Storage {
|
||||
t.Helper()
|
||||
|
||||
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
|
||||
Email: "user@example.com",
|
||||
Username: "User",
|
||||
UserID: "user-1",
|
||||
Hash: hash,
|
||||
}))
|
||||
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
|
||||
UserID: "user-1",
|
||||
ConnectorID: localConnectorID,
|
||||
MFASecrets: map[string]*storage.MFASecret{
|
||||
defaultTOTPAuthenticatorID: {
|
||||
AuthenticatorID: defaultTOTPAuthenticatorID,
|
||||
Type: "TOTP",
|
||||
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
|
||||
Confirmed: true,
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
|
||||
"webauthn": {{CredentialID: []byte("credential")}},
|
||||
},
|
||||
}))
|
||||
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
|
||||
UserID: "user-1",
|
||||
ConnectorID: localConnectorID,
|
||||
Nonce: "nonce",
|
||||
}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
|
||||
|
||||
return st
|
||||
}
|
||||
|
||||
func TestRunChangePassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "Password updated")
|
||||
|
||||
user, err := st.GetPassword(ctx, "user@example.com")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunChangePasswordValidatesPassword(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid password")
|
||||
}
|
||||
|
||||
func TestRunResetMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
|
||||
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "MFA reset")
|
||||
|
||||
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, identity.MFASecrets)
|
||||
require.Empty(t, identity.WebAuthnCredentials)
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
old.MFASecrets = nil
|
||||
old.WebAuthnCredentials = nil
|
||||
return old, nil
|
||||
}))
|
||||
|
||||
var out bytes.Buffer
|
||||
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "No MFA enrollment found")
|
||||
}
|
||||
|
||||
func TestSetIDPClientsMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, true))
|
||||
status, err := idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "enabled", status)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, false))
|
||||
status, err = idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "disabled", status)
|
||||
}
|
||||
|
||||
func TestUserSelectorValidate(t *testing.T) {
|
||||
require.NoError(t, userSelector{email: " user@example.com "}.validate())
|
||||
require.NoError(t, userSelector{userID: "user-1"}.validate())
|
||||
require.Error(t, userSelector{}.validate())
|
||||
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
|
||||
}
|
||||
|
||||
func TestFindLocalUserNotFound(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "not found"))
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputFromStdin(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(strings.NewReader("NewPass1!\n"))
|
||||
|
||||
password, err := resolvePasswordInput(cmd, "", "-")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "NewPass1!", password)
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
|
||||
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func init() {
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
tc := newTokenCommands()
|
||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(tc)
|
||||
ac := newAdminCommands()
|
||||
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(ac)
|
||||
}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var tokenDatadir string
|
||||
|
||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(withTokenStore)
|
||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if tokenDatadir != "" {
|
||||
datadir = tokenDatadir
|
||||
}
|
||||
|
||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
@@ -627,6 +627,21 @@ var providers = []Provider{
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
// vLLM is an OpenAI-compatible self-hosted server. It behaves like
|
||||
// the generic custom entry; it gets its own catalog id purely so it
|
||||
// surfaces as a named "vLLM" choice in the provider picker.
|
||||
ID: "vllm",
|
||||
Kind: KindCustom,
|
||||
Name: "vLLM",
|
||||
Description: "Self-hosted vLLM (OpenAI-compatible)",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#30A2FF",
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "custom",
|
||||
Kind: KindCustom,
|
||||
|
||||
@@ -47,16 +47,13 @@ func init() {
|
||||
precomputedDeprecatedRemotePeersConstraint = constraint
|
||||
}
|
||||
|
||||
// toNetbirdConfig converts the server configuration to the wire representation. It returns
|
||||
// nil when no server config is set (the fan-out network-map path) because clients treat any
|
||||
// non-nil config as authoritative: a config without a relay section is interpreted as relay
|
||||
// disabled and wipes the clients' relay URLs.
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings, settings *types.Settings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
if settings == nil {
|
||||
return nil
|
||||
}
|
||||
return &proto.NetbirdConfig{
|
||||
Metrics: &proto.MetricsConfig{
|
||||
Enabled: settings.MetricsPushEnabled,
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var stuns []*proto.HostConfig
|
||||
|
||||
@@ -8,11 +8,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -263,3 +265,39 @@ func TestEncodeSessionExpiresAt(t *testing.T) {
|
||||
assert.True(t, got.AsTime().Equal(deadline))
|
||||
})
|
||||
}
|
||||
|
||||
// TestToNetbirdConfig_RelayInvariant guards against the v0.74.0 relay-wipe regression.
|
||||
// Clients treat any non-nil NetbirdConfig as authoritative and interpret a missing relay
|
||||
// section as relay disabled, wiping their relay URLs. toNetbirdConfig must therefore
|
||||
// return nil when no server config is set (the fan-out network-map path) instead of a
|
||||
// partial config, and a result built from a relay-enabled config must carry the relay
|
||||
// section.
|
||||
func TestToNetbirdConfig_RelayInvariant(t *testing.T) {
|
||||
settings := &types.Settings{MetricsPushEnabled: true}
|
||||
|
||||
t.Run("nil server config returns nil config", func(t *testing.T) {
|
||||
nbCfg := toNetbirdConfig(nil, nil, nil, nil, settings)
|
||||
assert.Nil(t, nbCfg, "fan-out updates must not carry a partial NetbirdConfig even when settings are present")
|
||||
})
|
||||
|
||||
t.Run("relay-enabled config carries relay section", func(t *testing.T) {
|
||||
cfg := &nbconfig.Config{
|
||||
Stuns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "stun:stun.example.com:3478"}},
|
||||
TURNConfig: &nbconfig.TURNConfig{
|
||||
Turns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "turn:turn.example.com:3478", Username: "user", Password: "pass"}},
|
||||
},
|
||||
Relay: &nbconfig.Relay{Addresses: []string{"rels://relay.example.com:443"}},
|
||||
Signal: &nbconfig.Host{Proto: nbconfig.HTTP, URI: "signal.example.com:10000"},
|
||||
}
|
||||
relayToken := &Token{Payload: "token-payload", Signature: "token-signature"}
|
||||
|
||||
nbCfg := toNetbirdConfig(cfg, nil, relayToken, nil, settings)
|
||||
require.NotNil(t, nbCfg)
|
||||
require.NotNil(t, nbCfg.Relay, "non-nil NetbirdConfig must include the relay section")
|
||||
assert.Equal(t, cfg.Relay.Addresses, nbCfg.Relay.Urls, "relay URLs should match the server config")
|
||||
assert.Equal(t, relayToken.Payload, nbCfg.Relay.TokenPayload, "relay token payload should be set")
|
||||
assert.Equal(t, relayToken.Signature, nbCfg.Relay.TokenSignature, "relay token signature should be set")
|
||||
require.NotNil(t, nbCfg.Metrics)
|
||||
assert.True(t, nbCfg.Metrics.Enabled, "metrics flag should carry the settings value")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1048,11 +1048,7 @@ func testUpdateAccountPeers(t *testing.T) {
|
||||
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.NotNil(t, update.Update.NetbirdConfig)
|
||||
assert.Nil(t, update.Update.NetbirdConfig.Stuns)
|
||||
assert.Nil(t, update.Update.NetbirdConfig.Turns)
|
||||
assert.Nil(t, update.Update.NetbirdConfig.Signal)
|
||||
assert.Nil(t, update.Update.NetbirdConfig.Relay)
|
||||
assert.Nil(t, update.Update.NetbirdConfig, "fan-out updates must not carry a NetbirdConfig; clients treat a config without relay as relay disabled and wipe their relay URLs")
|
||||
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
|
||||
}
|
||||
|
||||
@@ -1849,12 +1849,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// validatePassword checks password strength requirements:
|
||||
// validatePassword checks password strength requirements.
|
||||
func validatePassword(password string) error {
|
||||
return ValidatePassword(password)
|
||||
}
|
||||
|
||||
// ValidatePassword checks password strength requirements:
|
||||
// - Minimum 8 characters
|
||||
// - At least 1 digit
|
||||
// - At least 1 uppercase letter
|
||||
// - At least 1 special character
|
||||
func validatePassword(password string) error {
|
||||
func ValidatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
@@ -33,10 +33,15 @@ const ConnectTimeout = 10 * time.Second
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
const (
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size
|
||||
// for the management client connection. Value is in bytes.
|
||||
EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE"
|
||||
|
||||
// defaultMaxRecvMsgSize is the max gRPC receive message size used for the
|
||||
// management client connection when EnvMaxRecvMsgSize is unset or invalid.
|
||||
// It overrides the gRPC library default of 4 MB.
|
||||
defaultMaxRecvMsgSize = 1024 * 1024 * 16
|
||||
|
||||
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
|
||||
errMsgNoMgmtConnection = "no connection to management"
|
||||
)
|
||||
@@ -84,22 +89,22 @@ type ExposeResponse struct {
|
||||
}
|
||||
|
||||
// MaxRecvMsgSize returns the configured max gRPC receive message size from
|
||||
// the environment, or 0 if unset (which uses the gRPC default of 4 MB).
|
||||
// the environment, or defaultMaxRecvMsgSize (16 MB) if unset or invalid.
|
||||
func MaxRecvMsgSize() int {
|
||||
val := os.Getenv(EnvMaxRecvMsgSize)
|
||||
if val == "" {
|
||||
return 0
|
||||
return defaultMaxRecvMsgSize
|
||||
}
|
||||
|
||||
size, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err)
|
||||
return 0
|
||||
return defaultMaxRecvMsgSize
|
||||
}
|
||||
|
||||
if size <= 0 {
|
||||
log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size)
|
||||
return 0
|
||||
return defaultMaxRecvMsgSize
|
||||
}
|
||||
|
||||
return size
|
||||
@@ -1030,8 +1035,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
BlockLANAccess: info.BlockLANAccess,
|
||||
BlockInbound: info.BlockInbound,
|
||||
DisableIPv6: info.DisableIPv6,
|
||||
|
||||
LazyConnectionEnabled: info.LazyConnectionEnabled,
|
||||
},
|
||||
|
||||
Capabilities: peerCapabilities(*info),
|
||||
|
||||
@@ -21,11 +21,11 @@ func TestMaxRecvMsgSize(t *testing.T) {
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{name: "unset returns 0", envValue: "", expected: 0},
|
||||
{name: "unset returns default", envValue: "", expected: defaultMaxRecvMsgSize},
|
||||
{name: "valid value", envValue: "10485760", expected: 10485760},
|
||||
{name: "non-numeric returns 0", envValue: "abc", expected: 0},
|
||||
{name: "negative returns 0", envValue: "-1", expected: 0},
|
||||
{name: "zero returns 0", envValue: "0", expected: 0},
|
||||
{name: "non-numeric returns default", envValue: "abc", expected: defaultMaxRecvMsgSize},
|
||||
{name: "negative returns default", envValue: "-1", expected: defaultMaxRecvMsgSize},
|
||||
{name: "zero returns default", envValue: "0", expected: defaultMaxRecvMsgSize},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user