Compare commits

..

4 Commits

Author SHA1 Message Date
riccardom
5740dd22e6 Remove verbose comments 2026-07-03 17:26:20 +02:00
riccardom
ec98c930cb [Recheck watcher ctx cancellation under conn.mu in onWGDisconnected
onWGDisconnected only checked conn.ctx (the engine-scoped context), never
the watcher's own context. disableWgWatcherIfNeeded cancels the wgWatcherCtx,
not conn.ctx, so a disabled watcher's timeout callback did not see the
cancellation.

handshakeCheck runs lock-free, so between the ctx check in periodicHandshakeCheck
and acquiring conn.mu a fast disconnect/reconnect can slip in: the stale watcher
then acquires the lock and tears down the *new*, healthy connection based on the
old timeout, forcing the guard into an unnecessary reconnect (flap).

Recheck watcherCtx.Err() under conn.mu so a superseded watcher exits without
touching the connection that replaced it.
2026-07-03 12:15:24 +02:00
riccardom
60104e000b Discriminate not updated from timeout handshakes 2026-07-03 12:02:50 +02:00
riccardom
d5a212349f Stick new watcher creation to actual existence of af the conn
and its removal to the removal of such same conn.
Avoid debouncing and cross lock dead locking
2026-07-03 11:37:41 +02:00
30 changed files with 265 additions and 2183 deletions

View File

@@ -195,7 +195,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
@@ -663,11 +662,12 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) onWGDisconnected() {
func (conn *Conn) onWGDisconnected(watcherCtx context.Context) {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
// watcherCtx guards against a stale watcher tearing down a connection that already superseded it.
if conn.ctx.Err() != nil || watcherCtx.Err() != nil {
return
}
@@ -802,25 +802,39 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
})
}
// enableWgWatcherIfNeeded starts a fresh watcher instance per connection attempt, so its
// lifecycle stays bound to conn.mu and enable/disable can't race an old goroutine's shutdown.
// Caller must hold conn.mu.
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
if !conn.wgWatcher.PrepareInitialHandshake() {
if conn.wgWatcher != nil {
return
}
watcher := NewWGWatcher(conn.Log, conn.config.WgConfig.WgInterface, conn.config.Key, conn.dumpState)
watcher.PrepareInitialHandshake()
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcher = watcher
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
onDisconnected := func() { conn.onWGDisconnected(wgWatcherCtx) }
watcher.EnableWgWatcher(wgWatcherCtx, enabledTime, onDisconnected, conn.onWGHandshakeSuccess)
}()
}
// disableWgWatcherIfNeeded cancels and drops the watcher once no transport is active. It never
// waits for the goroutine: the timeout path reentrantly calls back here under conn.mu, so
// blocking would deadlock. Caller must hold conn.mu.
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
if conn.currentConnPriority != conntype.None || conn.wgWatcher == nil {
return
}
conn.wgWatcherCancel()
conn.wgWatcher = nil
conn.wgWatcherCancel = nil
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
@@ -843,7 +857,9 @@ func (conn *Conn) resetEndpoint() {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if conn.wgWatcher != nil {
conn.wgWatcher.Reset()
}
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}

View File

@@ -85,11 +85,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.initialTicker(ctx)
defer func() {
// If backoff.Ticker.send is blocked, context.Done will not close the Ticker goroutine.
// We have to explicitly call Stop, even if we use backoff.WithContext.
ticker.Stop()
}()
defer ticker.Stop()
tickerChannel := ticker.C

View File

@@ -1,92 +0,0 @@
package guard
import (
"context"
"runtime"
"strings"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer/ice"
)
func newTestGuard(status connStatusFunc) *Guard {
srw := NewSRWatcher(nil, nil, nil, ice.Config{})
return NewGuard(log.WithField("test", "guard"), status, 50*time.Millisecond, srw)
}
// countBackoffTickerGoroutines returns how many goroutines are currently sitting
// in backoff/v4.(*Ticker).run (a ticker goroutine that has not exited).
func countBackoffTickerGoroutines() int {
buf := make([]byte, 1<<25) // 32MB
n := runtime.Stack(buf, true)
return strings.Count(string(buf[:n]), "backoff/v4.(*Ticker).run")
}
// TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown reproduces a observed
// leak: after a shutdown burst, ticker run/send goroutines stay parked
// forever even though every reconnect loop has exited.
func TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown(t *testing.T) {
before := countBackoffTickerGoroutines()
const peers = 6000
cancels := make([]context.CancelFunc, 0, peers)
var wg sync.WaitGroup
// A status check slower than the tick cadence. This models the real
// isConnectedOnAllWay/callback doing work: while the loop is busy in the
// handler, the ticker fires the next tick and parks in send(), because
// send() never selects on ctx.
slowStatus := func() ConnStatus {
time.Sleep(70 * time.Millisecond)
return ConnStatusConnected
}
for range peers {
g := newTestGuard(slowStatus)
ctx, cancel := context.WithCancel(context.Background())
cancels = append(cancels, cancel)
wg.Add(1)
go func() {
defer wg.Done()
g.Start(ctx, func() {})
}()
// Force the live ticker to be a newReconnectTicker.
g.SetRelayedConnDisconnected()
}
// Let the replacement tickers get past their 800ms initial interval, so
// many are parked in send() waiting on the (slow) consumer when we tear
// everything down.
time.Sleep(1500 * time.Millisecond)
// Shutdown burst: cancel every peer at once, like engine teardown.
for _, c := range cancels {
c()
}
// Every reconnect loop must return
waitCh := make(chan struct{})
go func() { wg.Wait(); close(waitCh) }()
select {
case <-waitCh:
case <-time.After(30 * time.Second):
t.Fatal("not all reconnect loops returned after ctx cancel")
}
// Give any correctly-stopped ticker goroutines time to unwind.
for range 50 {
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
}
leaked := countBackoffTickerGoroutines() - before
t.Logf("backoff Ticker.run goroutines still parked after teardown of %d peers: %d", peers, leaked)
if leaked > 0 {
t.Errorf("LEAK: %d backoff ticker goroutines parked after all reconnect loops exited "+
"(defer ticker.Stop() stops the initial ticker, not the live replacement)", leaked)
}
}

View File

@@ -3,7 +3,6 @@ package peer
import (
"context"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -24,14 +23,14 @@ type WGInterfaceStater interface {
GetStats() (map[string]configurer.WGStats, error)
}
// WGWatcher is single-shot: one instance per connection attempt, run once, then discarded.
// Lifecycle is owned by Conn under conn.mu, so it keeps no "enabled" state to go stale.
type WGWatcher struct {
log *log.Entry
wgIfaceStater WGInterfaceStater
peerKey string
stateDump *stateDump
enabled bool
muEnabled sync.Mutex
// initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently.
initialHandshake time.Time
@@ -48,25 +47,14 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
}
// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard
// handshake time. It must be called before the peer is (re)configured on the WireGuard
// interface, so the captured baseline reflects the state prior to this connection attempt
// instead of racing with that configuration. Returns ok=false if the watcher is already
// running, in which case EnableWgWatcher must not be called.
func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return false
}
// PrepareInitialHandshake reads the peer's current WireGuard handshake time. It must be
// called before the peer is (re)configured on the WireGuard interface, so the captured
// baseline reflects the state prior to this connection attempt instead of racing with
// that configuration.
func (w *WGWatcher) PrepareInitialHandshake() {
w.log.Debugf("enable WireGuard watcher")
w.enabled = true
w.muEnabled.Unlock()
handshake, _ := w.wgState()
w.initialHandshake = handshake
return true
}
// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by
@@ -74,10 +62,6 @@ func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
// for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}
// Reset signals the watcher that the WireGuard peer has been reset and a new
@@ -103,6 +87,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
case <-timer.C:
handshake, ok := w.handshakeCheck(lastHandshake)
if !ok {
// early ctx cancel check return
if ctx.Err() != nil {
return
}
@@ -147,9 +132,9 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
// the current know handshake did not change
// the current known handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
w.log.Warnf("WireGuard handshake not updated: %v", handshake)
return nil, false
}

View File

@@ -7,7 +7,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/configurer"
)
@@ -35,8 +34,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ok := watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should not be enabled yet")
watcher.PrepareInitialHandshake()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
@@ -66,8 +64,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
ok := watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should not be enabled yet")
watcher.PrepareInitialHandshake()
wg := &sync.WaitGroup{}
wg.Add(1)
@@ -83,8 +80,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
ok = watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should be re-enabled after the previous run stopped")
watcher.PrepareInitialHandshake()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, time.Now(), func() {

View File

@@ -1,151 +0,0 @@
package cmd
import (
"context"
"fmt"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newAdminCommands creates the admin command tree with combined-specific resource openers.
func newAdminCommands() *cobra.Command {
return admincmd.NewCommands(admincmd.Openers{
Resources: withAdminResources,
Store: withAdminStoreOnly,
IDP: withAdminIDPOnly,
})
}
func newLegacyTokenCommand() *cobra.Command {
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
cmd.Deprecated = "use 'admin token' instead"
return cmd
}
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
mgmtConfig, err := adminManagementConfig(cfg)
if err != nil {
return err
}
managementStore, err := openAdminStore(ctx, cfg)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
eventStore, esErr := openAdminEventStore(ctx, cfg, mgmtConfig)
if esErr != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
}
if eventStore != nil {
defer func() {
if err := eventStore.Close(ctx); err != nil {
log.Debugf("close activity event store: %v", err)
}
}()
}
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
})
}
// withAdminStoreOnly opens only the management store for admin subcommands that do not
// need embedded IdP storage.
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
managementStore, err := openAdminStore(ctx, cfg)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
return fn(ctx, managementStore)
})
}
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
mgmtConfig, err := adminManagementConfig(cfg)
if err != nil {
return err
}
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
return fn(ctx, idpStorage, idpStorageFile)
})
}
func withAdminConfig(cmd *cobra.Command, fn func(ctx context.Context, cfg *CombinedConfig) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
cfg.ApplyAdminDefaults()
applyServerStoreEnv(cfg.Server.Store)
return fn(ctx, cfg)
}
func adminManagementConfig(cfg *CombinedConfig) (*nbconfig.Config, error) {
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return nil, fmt.Errorf("create management config: %w", err)
}
return mgmtConfig, nil
}
func openAdminStore(ctx context.Context, cfg *CombinedConfig) (store.Store, error) {
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
if err != nil {
return nil, fmt.Errorf("create store: %w", err)
}
return managementStore, nil
}
func openAdminEventStore(ctx context.Context, cfg *CombinedConfig, config *nbconfig.Config) (activity.Store, error) {
if config.DataStoreEncryptionKey == "" {
return nil, fmt.Errorf("data store encryption key is not configured")
}
if err := applyActivityStoreEnv(cfg.Server.ActivityStore); err != nil {
return nil, fmt.Errorf("configure activity event store: %w", err)
}
eventStore, err := activitystore.NewSqlStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return nil, fmt.Errorf("open activity event store: %w", err)
}
if eventStore == nil {
return nil, fmt.Errorf("open activity event store: returned nil store")
}
return eventStore, nil
}

View File

@@ -1,47 +0,0 @@
package cmd
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestApplyAdminDefaultsCopiesServerStoreWithoutExposedAddress(t *testing.T) {
cfg := DefaultConfig()
cfg.Server.ExposedAddress = ""
cfg.Server.DataDir = "/srv/netbird"
cfg.Server.Store = StoreConfig{
Engine: "postgres",
DSN: "postgres://user:pass@example.com/netbird",
}
cfg.ApplyAdminDefaults()
require.Equal(t, "/srv/netbird", cfg.Management.DataDir)
require.Equal(t, "postgres", cfg.Management.Store.Engine)
require.Equal(t, cfg.Server.Store.DSN, cfg.Management.Store.DSN)
}
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
eventStore, err := openAdminEventStore(context.Background(), &CombinedConfig{}, &nbconfig.Config{})
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key")
require.Nil(t, eventStore)
}
func TestApplyServerStoreEnv(t *testing.T) {
t.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", "")
t.Setenv("NB_STORE_ENGINE_MYSQL_DSN", "")
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", "")
applyServerStoreEnv(StoreConfig{Engine: "postgres", DSN: "postgres-dsn", File: "store.db"})
require.Equal(t, "postgres-dsn", os.Getenv("NB_STORE_ENGINE_POSTGRES_DSN"))
require.Equal(t, "store.db", os.Getenv("NB_STORE_ENGINE_SQLITE_FILE"))
applyServerStoreEnv(StoreConfig{Engine: "mysql", DSN: "mysql-dsn"})
require.Equal(t, "mysql-dsn", os.Getenv("NB_STORE_ENGINE_MYSQL_DSN"))
}

View File

@@ -6,7 +6,8 @@ import (
"net"
"net/netip"
"os"
filePath "path/filepath"
"path"
"path/filepath"
"strings"
"time"
@@ -298,19 +299,6 @@ func (c *CombinedConfig) ApplySimplifiedDefaults() {
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// ApplyAdminDefaults applies the management settings needed by admin commands even
// when the full server config is invalid and ApplySimplifiedDefaults cannot run.
func (c *CombinedConfig) ApplyAdminDefaults() {
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" || c.Server.Store.File != "" || c.Server.Store.DSN != "" {
c.Management.Store = c.Server.Store
}
}
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
@@ -588,11 +576,11 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = filePath.Join(mgmt.DataDir, "idp.db")
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filePath.IsAbs(authStorageFile) {
authStorageFile = filePath.Join(mgmt.DataDir, authStorageFile)
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
}
}
}
@@ -739,7 +727,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = filePath.Join(cfg.Datadir, "idp.db")
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer

View File

@@ -64,8 +64,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
rootCmd.AddCommand(newAdminCommands())
rootCmd.AddCommand(newLegacyTokenCommand())
rootCmd.AddCommand(newTokenCommands())
}
func RootCmd() *cobra.Command {
@@ -123,37 +122,6 @@ func execute(cmd *cobra.Command, _ []string) error {
}
// initializeConfig loads and validates the configuration, then initializes logging.
func applyServerStoreEnv(storeConfig StoreConfig) {
if dsn := storeConfig.DSN; dsn != "" {
switch strings.ToLower(storeConfig.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := storeConfig.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
}
func applyActivityStoreEnv(storeConfig StoreConfig) error {
if engine := storeConfig.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && storeConfig.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := storeConfig.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := storeConfig.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
return nil
}
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
@@ -169,10 +137,30 @@ func initializeConfig() error {
return fmt.Errorf("failed to initialize log: %w", err)
}
applyServerStoreEnv(config.Server.Store)
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
if err := applyActivityStoreEnv(config.Server.ActivityStore); err != nil {
return err
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
log.Infof("Starting combined NetBird server")

63
combined/cmd/token.go Normal file
View File

@@ -0,0 +1,63 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newTokenCommands creates the token command tree with combined-specific store opener.
func newTokenCommands() *cobra.Command {
return tokencmd.NewCommands(withTokenStore)
}
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)
s, err := store.NewStore(ctx, engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -41,7 +41,7 @@ type Config struct {
GRPCAddr string
}
const LocalConnectorID = "local"
const localConnectorID = "local"
// Provider wraps a Dex server
type Provider struct {
@@ -495,60 +495,18 @@ func (p *Provider) Storage() storage.Storage {
return p.storage
}
// SetClientsMFAChain updates the MFAChain field on OAuth2 clients in Dex storage.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func SetClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, mfaChain []string) error {
previousChains := make(map[string][]string, len(clientIDs))
for _, clientID := range clientIDs {
client, err := st.GetClient(ctx, clientID)
if err != nil {
return fmt.Errorf("failed to get client %s before MFA chain update: %w", clientID, err)
}
previousChains[clientID] = cloneMFAChain(client.MFAChain)
}
updatedClientIDs := make([]string, 0, len(clientIDs))
for _, clientID := range clientIDs {
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = cloneMFAChain(mfaChain)
return old, nil
}); err != nil {
if rollbackErr := rollbackClientsMFAChain(ctx, st, updatedClientIDs, previousChains); rollbackErr != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w (also failed to roll back previous MFA chains: %v)", clientID, err, rollbackErr)
}
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
updatedClientIDs = append(updatedClientIDs, clientID)
}
return nil
}
func rollbackClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, previousChains map[string][]string) error {
var rollbackErrs []error
for i := len(clientIDs) - 1; i >= 0; i-- {
clientID := clientIDs[i]
previousChain := cloneMFAChain(previousChains[clientID])
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = previousChain
return old, nil
}); err != nil {
rollbackErrs = append(rollbackErrs, fmt.Errorf("client %s: %w", clientID, err))
}
}
return errors.Join(rollbackErrs...)
}
func cloneMFAChain(chain []string) []string {
if chain == nil {
return nil
}
return append([]string(nil), chain...)
}
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
return SetClientsMFAChain(ctx, p.storage, clientIDs, mfaChain)
for _, clientID := range clientIDs {
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
}
return nil
}
// Handler returns the Dex server as an http.Handler for embedding in another server.
@@ -588,7 +546,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
// This matches the format Dex uses in JWT tokens
encodedID := EncodeDexUserID(userID, LocalConnectorID)
encodedID := EncodeDexUserID(userID, localConnectorID)
return encodedID, nil
}
@@ -667,7 +625,7 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
// local password connector.
func IsLocalUserID(encodedID string) bool {
_, connectorID, err := DecodeDexUserID(encodedID)
return err == nil && connectorID == LocalConnectorID
return err == nil && connectorID == localConnectorID
}
// GetUser returns a user by email

View File

@@ -3,8 +3,6 @@ package dex
import (
"context"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
@@ -13,44 +11,11 @@ import (
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
sqllib "github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type updateFailingStorage struct {
storage.Storage
failClientID string
}
func (s *updateFailingStorage) UpdateClient(ctx context.Context, id string, updater func(storage.Client) (storage.Client, error)) error {
if id == s.failClientID {
return errors.New("forced update failure")
}
return s.Storage.UpdateClient(ctx, id, updater)
}
func TestSetClientsMFAChainRollsBackUpdatedClients(t *testing.T) {
ctx := context.Background()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-1", MFAChain: []string{"old-1"}}))
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-2", MFAChain: []string{"old-2"}}))
err := SetClientsMFAChain(ctx, &updateFailingStorage{Storage: st, failClientID: "client-2"}, []string{"client-1", "client-2"}, []string{"new"})
require.Error(t, err)
require.Contains(t, err.Error(), "failed to update MFA chain on client client-2")
client1, err := st.GetClient(ctx, "client-1")
require.NoError(t, err)
require.Equal(t, []string{"old-1"}, client1.MFAChain)
client2, err := st.GetClient(ctx, "client-2")
require.NoError(t, err)
require.Equal(t, []string{"old-2"}, client2.MFAChain)
}
func TestUserCreationFlow(t *testing.T) {
ctx := context.Background()

View File

@@ -556,7 +556,7 @@ start_services_and_show_instructions() {
echo "Creating proxy access token..."
# Use docker exec with bash to run the token command directly
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T netbird-server \
/go/bin/netbird-server admin token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
/go/bin/netbird-server token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
if [[ -z "$PROXY_TOKEN" ]]; then
echo "ERROR: Failed to create proxy token. Check netbird-server logs." > /dev/stderr

View File

@@ -1,177 +0,0 @@
package cmd
import (
"context"
"fmt"
"path"
"path/filepath"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var adminDatadir string
// newAdminCommands creates the admin command tree with management-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(admincmd.Openers{
Resources: withAdminResources,
Store: withAdminStoreOnly,
IDP: withAdminIDPOnly,
})
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
return cmd
}
func newLegacyTokenCommand() *cobra.Command {
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
cmd.Deprecated = "use 'admin token' instead"
cmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
return cmd
}
// withAdminResources initializes logging, loads config, opens the management store
// and embedded IdP storage, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
managementStore, err := openAdminStore(ctx, config, datadir)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
eventStore, esErr := openAdminEventStore(ctx, config, datadir)
if esErr != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
}
if eventStore != nil {
defer func() {
if err := eventStore.Close(ctx); err != nil {
log.Debugf("close activity event store: %v", err)
}
}()
}
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
})
}
// withAdminStoreOnly opens only the management store for admin subcommands that do not
// need embedded IdP storage.
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminConfig(cmd, false, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
managementStore, err := openAdminStore(ctx, config, datadir)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
return fn(ctx, managementStore)
})
}
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, _ string) error {
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
return fn(ctx, idpStorage, idpStorageFile)
})
}
func withAdminConfig(cmd *cobra.Command, applyIDPDefaults bool, fn func(ctx context.Context, config *nbconfig.Config, datadir string) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, datadir, err := loadAdminMgmtConfig(ctx, applyIDPDefaults)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
return fn(ctx, config, datadir)
}
func loadAdminMgmtConfig(ctx context.Context, applyIDPDefaults bool) (*nbconfig.Config, string, error) {
config := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(nbconfig.MgmtConfigPath, config); err != nil {
return nil, "", err
}
if applyIDPDefaults {
if err := ApplyEmbeddedIdPConfig(ctx, config); err != nil {
return nil, "", err
}
}
datadir := config.Datadir
applyAdminDatadirOverride(config, &datadir)
return config, datadir, nil
}
func applyAdminDatadirOverride(config *nbconfig.Config, datadir *string) {
if adminDatadir == "" {
return
}
oldDatadir := *datadir
*datadir = adminDatadir
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" && isDefaultIDPStorageFile(config.EmbeddedIdP.Storage.Config.File, oldDatadir) {
config.EmbeddedIdP.Storage.Config.File = filepath.Join(*datadir, "idp.db")
}
}
func isDefaultIDPStorageFile(file, datadir string) bool {
if file == "" {
return true
}
defaultFile := filepath.Join(datadir, "idp.db")
legacyDefaultFile := path.Join(datadir, "idp.db")
legacySlashDefaultFile := path.Join(filepath.ToSlash(datadir), "idp.db")
return filepath.Clean(file) == filepath.Clean(defaultFile) ||
file == legacyDefaultFile ||
filepath.ToSlash(file) == legacySlashDefaultFile
}
func openAdminStore(ctx context.Context, config *nbconfig.Config, datadir string) (store.Store, error) {
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return nil, fmt.Errorf("create store: %w", err)
}
return managementStore, nil
}
func openAdminEventStore(ctx context.Context, config *nbconfig.Config, datadir string) (activity.Store, error) {
if config.DataStoreEncryptionKey == "" {
return nil, fmt.Errorf("data store encryption key is not configured")
}
eventStore, err := activitystore.NewSqlStore(ctx, datadir, config.DataStoreEncryptionKey)
if err != nil {
return nil, fmt.Errorf("open activity event store: %w", err)
}
if eventStore == nil {
return nil, fmt.Errorf("open activity event store: returned nil store")
}
return eventStore, nil
}

View File

@@ -1,577 +0,0 @@
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
// Both the management and combined binaries use these commands, each providing
// their own opener to handle config loading and storage initialization.
package admincmd
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"strings"
"time"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
"github.com/netbirdio/netbird/formatter/hook"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/cmd/proxy"
"github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// Resources contains the storages required by the admin commands.
type Resources struct {
Store store.Store
IDPStorage storage.Storage
IDPStorageFile string
EventStore activity.Store
}
// Opener initializes command resources from the command context and calls fn.
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
// StoreOpener initializes only the management store from the command context and calls fn.
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
// IDPOpener initializes only the embedded IdP storage from the command context and calls fn.
type IDPOpener func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error
// Openers contains the resource openers needed by the admin command tree.
type Openers struct {
Resources Opener
Store StoreOpener
IDP IDPOpener
}
type userSelector struct {
email string
userID string
}
func (s userSelector) normalized() userSelector {
return userSelector{
email: strings.TrimSpace(s.email),
userID: strings.TrimSpace(s.userID),
}
}
func (s userSelector) validate() error {
s = s.normalized()
if (s.email == "") == (s.userID == "") {
return fmt.Errorf("provide exactly one of --email or --user-id")
}
return nil
}
// NewCommands creates the admin command tree with the given resource openers.
func NewCommands(openers Openers) *cobra.Command {
adminCmd := &cobra.Command{
Use: "admin",
Short: "Self-hosted administrator helpers",
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
}
userCmd := &cobra.Command{
Use: "user",
Short: "Manage local embedded IdP users",
}
var passwordSelector userSelector
var password string
var passwordFile string
passwordCmd := &cobra.Command{
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
Aliases: []string{"set-password"},
Short: "Change a local user's password",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
if err := passwordSelector.validate(); err != nil {
return err
}
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
if err != nil {
return err
}
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
return runChangePassword(ctx, idpStorage, cmd.OutOrStdout(), passwordSelector, newPassword, storageFile)
})
},
}
addUserSelectorFlags(passwordCmd, &passwordSelector)
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
var resetSelector userSelector
resetMFACmd := &cobra.Command{
Use: "reset-mfa (--email email | --user-id id)",
Short: "Reset a local user's MFA enrollment",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
if err := resetSelector.validate(); err != nil {
return err
}
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
return runResetMFA(ctx, idpStorage, cmd.OutOrStdout(), resetSelector, storageFile)
})
},
}
addUserSelectorFlags(resetMFACmd, &resetSelector)
userCmd.AddCommand(passwordCmd, resetMFACmd)
mfaCmd := &cobra.Command{
Use: "mfa",
Short: "Manage local MFA for embedded IdP users",
}
enableCmd := &cobra.Command{
Use: "enable",
Short: "Enable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
})
},
}
disableCmd := &cobra.Command{
Use: "disable",
Short: "Disable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
})
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show local MFA status",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
})
},
}
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
adminCmd.AddCommand(userCmd, mfaCmd)
if openers.Store != nil {
adminCmd.AddCommand(tokencmd.NewCommands(tokencmd.StoreOpener(openers.Store)))
adminCmd.AddCommand(proxycmd.NewCommands(proxycmd.StoreOpener(openers.Store)))
}
return adminCmd
}
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
if cfg == nil || !cfg.Enabled {
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
}
yamlConfig, err := cfg.ToYAMLConfig()
if err != nil {
return nil, fmt.Errorf("build embedded IdP config: %w", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
st, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
}
return st, nil
}
// CloseStore closes the management store and logs cleanup errors at debug level.
func CloseStore(ctx context.Context, s store.Store) {
if s == nil {
return
}
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}
// OpenIDPStorage opens embedded IdP storage and returns its sqlite file path when applicable.
func OpenIDPStorage(config *nbconfig.Config) (storage.Storage, string, error) {
if config == nil {
return nil, "", fmt.Errorf("management config is required")
}
idpStorage, err := OpenEmbeddedIDPStorage(config.EmbeddedIdP)
if err != nil {
return nil, "", err
}
return idpStorage, embeddedIDPStorageFile(config), nil
}
func embeddedIDPStorageFile(config *nbconfig.Config) string {
if config.EmbeddedIdP == nil || config.EmbeddedIdP.Storage.Type != "sqlite3" {
return ""
}
return config.EmbeddedIdP.Storage.Config.File
}
// CloseIDPStorage closes embedded IdP storage and logs cleanup errors at debug level.
func CloseIDPStorage(s storage.Storage) {
if s == nil {
return
}
if err := s.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
}
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
if password != "" && passwordFile != "" {
return "", fmt.Errorf("provide only one of --password or --password-file")
}
if passwordFile == "" {
return password, nil
}
var data []byte
var err error
if passwordFile == "-" {
data, err = io.ReadAll(cmd.InOrStdin())
} else {
data, err = os.ReadFile(passwordFile)
}
if err != nil {
return "", fmt.Errorf("read password: %w", err)
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string, idpStorageFile string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
if password == "" {
return fmt.Errorf("password is required")
}
if err := server.ValidatePassword(password); err != nil {
return fmt.Errorf("invalid password: %w", err)
}
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = hash
return old, nil
}); err != nil {
return fmt.Errorf("update password for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
return nil
}
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, idpStorageFile string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
if err != nil {
return err
}
reset := false
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
old.MFASecrets = map[string]*storage.MFASecret{}
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
return old, nil
})
if errors.Is(err, storage.ErrNotFound) {
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
return nil
}
if err != nil {
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
if reset {
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
} else {
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
}
return nil
}
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accountID, settings, err := getSingleAccountSettings(ctx, resources.Store)
if err != nil {
return err
}
oldEnabled := settings.LocalMfaEnabled
newSettings := settings.Copy()
newSettings.LocalMfaEnabled = enabled
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
return err
}
if err := resources.Store.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
if rollbackErr := setIDPClientsMFA(ctx, resources.IDPStorage, oldEnabled); rollbackErr != nil {
return fmt.Errorf("save local MFA account setting: %w (also failed to roll back embedded IdP MFA state: %v)", err, rollbackErr)
}
return fmt.Errorf("save local MFA account setting: %w", err)
}
if err := storeMFAActivity(ctx, resources.EventStore, accountID, enabled); err != nil {
_, _ = fmt.Fprintf(w, "Warning: failed to record audit event: %v\n", err)
}
state := "disabled"
if enabled {
state = "enabled"
}
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
return nil
}
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
_, settings, err := getSingleAccountSettings(ctx, resources.Store)
if err != nil {
return err
}
accountStatus := "disabled"
if settings.LocalMfaEnabled {
accountStatus = "enabled"
}
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
if err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
return nil
}
func getSingleAccountSettings(ctx context.Context, s store.Store) (string, *types.Settings, error) {
count, err := s.GetAccountsCounter(ctx)
if err != nil {
return "", nil, fmt.Errorf("count accounts: %w", err)
}
if count != 1 {
return "", nil, fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", count)
}
accountID, err := s.GetAnyAccountID(ctx)
if err != nil {
return "", nil, fmt.Errorf("get account ID: %w", err)
}
settings, err := s.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return "", nil, fmt.Errorf("get account settings: %w", err)
}
if settings == nil {
settings = &types.Settings{}
}
return accountID, settings, nil
}
func storeMFAActivity(ctx context.Context, eventStore activity.Store, accountID string, enabled bool) error {
if eventStore == nil {
return nil
}
event := activity.AccountLocalMfaDisabled
if enabled {
event = activity.AccountLocalMfaEnabled
}
_, err := eventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: event,
InitiatorID: string(hook.SystemSource),
TargetID: accountID,
AccountID: accountID,
})
if err != nil {
return fmt.Errorf("save local MFA audit event: %w", err)
}
return nil
}
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector, idpStorageFile string) (storage.Password, error) {
selector = selector.normalized()
if err := selector.validate(); err != nil {
return storage.Password{}, err
}
if selector.email != "" {
user, err := idpStorage.GetPassword(ctx, selector.email)
if errors.Is(err, storage.ErrNotFound) {
if empty, listErr := localUsersEmpty(ctx, idpStorage); listErr != nil {
return storage.Password{}, listErr
} else if empty {
return storage.Password{}, noLocalUsersError(idpStorageFile)
}
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
}
if err != nil {
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
}
return user, nil
}
rawUserID := selector.userID
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
rawUserID = decodedUserID
}
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return storage.Password{}, fmt.Errorf("list local users: %w", err)
}
for _, user := range users {
if user.UserID == rawUserID || user.UserID == selector.userID {
return user, nil
}
}
if len(users) == 0 {
return storage.Password{}, noLocalUsersError(idpStorageFile)
}
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
}
func localUsersEmpty(ctx context.Context, idpStorage storage.Storage) (bool, error) {
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return false, fmt.Errorf("list local users: %w", err)
}
return len(users) == 0, nil
}
func noLocalUsersError(idpStorageFile string) error {
location := ""
if idpStorageFile != "" {
location = fmt.Sprintf(" (%s)", idpStorageFile)
}
return fmt.Errorf("no local users exist in the embedded IdP storage%s; the management server may never have started with this config, or --datadir points at the wrong location", location)
}
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
err := idpStorage.DeleteAuthSession(ctx, userID, idp.LocalConnectorID)
if err == nil || errors.Is(err, storage.ErrNotFound) {
return nil
}
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
}
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{idp.DefaultTOTPAuthenticatorID}
}
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
if err := nbdex.SetClientsMFAChain(ctx, idpStorage, clientIDs, mfaChain); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("embedded IdP client not found; start the management server once before toggling MFA: %w", err)
}
return fmt.Errorf("update MFA chain on embedded IdP clients: %w", err)
}
return nil
}
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
enabledCount := 0
for _, clientID := range clientIDs {
client, err := idpStorage.GetClient(ctx, clientID)
if errors.Is(err, storage.ErrNotFound) {
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
}
if err != nil {
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
}
if hasAuthenticator(client.MFAChain, idp.DefaultTOTPAuthenticatorID) {
enabledCount++
}
}
switch enabledCount {
case 0:
return "disabled", nil
case len(clientIDs):
return "enabled", nil
default:
return "partially enabled", nil
}
}
func hasAuthenticator(chain []string, authenticatorID string) bool {
for _, id := range chain {
if id == authenticatorID {
return true
}
}
return false
}

View File

@@ -1,250 +0,0 @@
package admincmd
import (
"bytes"
"context"
"io"
"log/slog"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
mgmtstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
func newTestIDPStorage(t *testing.T) storage.Storage {
t.Helper()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
require.NoError(t, err)
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
Email: "user@example.com",
Username: "User",
UserID: "user-1",
Hash: hash,
}))
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
UserID: "user-1",
ConnectorID: idp.LocalConnectorID,
MFASecrets: map[string]*storage.MFASecret{
idp.DefaultTOTPAuthenticatorID: {
AuthenticatorID: idp.DefaultTOTPAuthenticatorID,
Type: "TOTP",
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
Confirmed: true,
CreatedAt: time.Now(),
},
},
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
"webauthn": {{CredentialID: []byte("credential")}},
},
}))
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
UserID: "user-1",
ConnectorID: idp.LocalConnectorID,
Nonce: "nonce",
}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientCLI, Name: "CLI"}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientDashboard, Name: "Dashboard"}))
return st
}
func TestRunChangePassword(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!", "")
require.NoError(t, err)
require.Contains(t, out.String(), "Password updated")
user, err := st.GetPassword(ctx, "user@example.com")
require.NoError(t, err)
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunChangePasswordValidatesPassword(t *testing.T) {
st := newTestIDPStorage(t)
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short", "")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid password")
}
func TestRunResetMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
encodedUserID := nbdex.EncodeDexUserID("user-1", idp.LocalConnectorID)
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID}, "")
require.NoError(t, err)
require.Contains(t, out.String(), "MFA reset")
identity, err := st.GetUserIdentity(ctx, "user-1", idp.LocalConnectorID)
require.NoError(t, err)
require.Empty(t, identity.MFASecrets)
require.Empty(t, identity.WebAuthnCredentials)
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.MFASecrets = nil
old.WebAuthnCredentials = nil
return old, nil
}))
var out bytes.Buffer
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"}, "")
require.NoError(t, err)
require.Contains(t, out.String(), "No MFA enrollment found")
}
func TestSetIDPClientsMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, setIDPClientsMFA(ctx, st, true))
status, err := idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "enabled", status)
require.NoError(t, setIDPClientsMFA(ctx, st, false))
status, err = idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "disabled", status)
}
func newTestManagementStore(t *testing.T, localMFAEnabled bool) mgmtstore.Store {
t.Helper()
ctx := context.Background()
st, err := mgmtstore.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, st.Close(ctx)) })
require.NoError(t, st.SaveAccount(ctx, &types.Account{
Id: "account-1",
Settings: &types.Settings{LocalMfaEnabled: localMFAEnabled},
}))
return st
}
func TestRunSetMFAEnabledDoesNotSaveWhenIDPUpdateFails(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
require.Error(t, err)
require.Contains(t, err.Error(), "embedded IdP client")
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.False(t, settings.LocalMfaEnabled)
}
func TestRunSetMFAEnabledUpdatesSettingsAfterIDP(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := newTestIDPStorage(t)
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
require.NoError(t, err)
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.True(t, settings.LocalMfaEnabled)
clientStatus, err := idpClientsMFAStatus(ctx, idpStorage)
require.NoError(t, err)
require.Equal(t, "enabled", clientStatus)
}
func TestRunSetMFAEnabledSucceedsWithNilEventStore(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := newTestIDPStorage(t)
var out bytes.Buffer
var err error
require.NotPanics(t, func() {
err = runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage, EventStore: nil}, &out, true)
})
require.NoError(t, err)
require.Contains(t, out.String(), "Local MFA enabled")
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.True(t, settings.LocalMfaEnabled)
}
func TestUserSelectorValidate(t *testing.T) {
require.NoError(t, userSelector{email: " user@example.com "}.validate())
require.NoError(t, userSelector{userID: "user-1"}.validate())
require.Error(t, userSelector{}.validate())
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
}
func TestFindLocalUserNotFound(t *testing.T) {
st := newTestIDPStorage(t)
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "")
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "not found"))
}
func TestFindLocalUserZeroUsersIncludesStoragePath(t *testing.T) {
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "/var/lib/netbird/idp.db")
require.Error(t, err)
require.Contains(t, err.Error(), "no local users exist")
require.Contains(t, err.Error(), "/var/lib/netbird/idp.db")
}
func TestUserCommandValidatesSelectorBeforeOpeningStorage(t *testing.T) {
opened := false
cmd := NewCommands(Openers{
IDP: func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
opened = true
return nil
},
})
cmd.SetArgs([]string{"user", "change-password", "--password", "NewPass1!"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()
require.Error(t, err)
require.Contains(t, err.Error(), "provide exactly one")
require.False(t, opened)
}
func TestResolvePasswordInputFromStdin(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetIn(strings.NewReader("NewPass1!\n"))
password, err := resolvePasswordInput(cmd, "", "-")
require.NoError(t, err)
require.Equal(t, "NewPass1!", password)
}
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
require.Error(t, err)
}

View File

@@ -1,80 +0,0 @@
package cmd
import (
"context"
"path"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
)
func TestApplyAdminDatadirOverrideRelocatesDefaultIDPStorage(t *testing.T) {
oldDatadir := filepath.Join(t.TempDir(), "old")
newDatadir := filepath.Join(t.TempDir(), "new")
for _, defaultFile := range []string{
"",
filepath.Join(oldDatadir, "idp.db"),
path.Join(oldDatadir, "idp.db"),
} {
t.Run(defaultFile, func(t *testing.T) {
cfg := &nbconfig.Config{
EmbeddedIdP: &idp.EmbeddedIdPConfig{
Enabled: true,
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: defaultFile,
},
},
},
}
datadir := oldDatadir
oldAdminDatadir := adminDatadir
adminDatadir = newDatadir
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
applyAdminDatadirOverride(cfg, &datadir)
require.Equal(t, newDatadir, datadir)
require.Equal(t, filepath.Join(newDatadir, "idp.db"), cfg.EmbeddedIdP.Storage.Config.File)
})
}
}
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
eventStore, err := openAdminEventStore(context.Background(), &nbconfig.Config{}, t.TempDir())
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key")
require.Nil(t, eventStore)
}
func TestApplyAdminDatadirOverrideKeepsExplicitIDPStorage(t *testing.T) {
oldDatadir := filepath.Join(t.TempDir(), "old")
newDatadir := filepath.Join(t.TempDir(), "new")
explicitFile := filepath.Join(t.TempDir(), "custom-idp.db")
cfg := &nbconfig.Config{
EmbeddedIdP: &idp.EmbeddedIdPConfig{
Enabled: true,
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: explicitFile,
},
},
},
}
datadir := oldDatadir
oldAdminDatadir := adminDatadir
adminDatadir = newDatadir
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
applyAdminDatadirOverride(cfg, &datadir)
require.Equal(t, newDatadir, datadir)
require.Equal(t, explicitFile, cfg.EmbeddedIdP.Storage.Config.File)
}

View File

@@ -13,7 +13,6 @@ import (
"os"
"os/signal"
"path"
"path/filepath"
"strings"
"syscall"
@@ -210,7 +209,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = filepath.Join(cfg.Datadir, "idp.db")
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer

View File

@@ -1,141 +0,0 @@
// Package proxycmd provides reusable cobra commands for managing reverse proxy instances.
// Both the management and combined binaries use these commands, each providing
// their own StoreOpener to handle config loading and store initialization.
package proxycmd
import (
"bufio"
"context"
"fmt"
"io"
"strings"
"text/tabwriter"
"github.com/spf13/cobra"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
)
// StoreOpener initializes a store from the command context and calls fn.
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
const disconnectAllConfirmation = "disconnect all proxies"
// NewCommands creates the proxy command tree with the given store opener.
// Returns the parent "proxy" command with the disconnect-all subcommand.
func NewCommands(opener StoreOpener) *cobra.Command {
var dryRun bool
var force bool
proxyCmd := &cobra.Command{
Use: "proxy",
Short: "Manage reverse proxy instances",
Long: "Commands for inspecting and repairing the reverse proxy instances registered with the management server.",
}
disconnectAllCmd := &cobra.Command{
Use: "disconnect-all",
Short: "Force-mark all reverse proxy instances as disconnected",
Long: "Lists all reverse proxy instances and force-marks them as disconnected, regardless of their session state. " +
"Use this to repair stale connection state, e.g. after an unclean management server shutdown. " +
"By default, it asks for manual confirmation before changing state. Use --dry-run to preview without changing state, or --force to skip confirmation. " +
"Run during a maintenance window; affected live proxies may stay hidden until their next heartbeat or reconnect/re-register.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, s store.Store) error {
return runDisconnectAll(ctx, s, cmd.OutOrStdout(), cmd.InOrStdin(), dryRun, force)
})
},
}
disconnectAllCmd.Flags().BoolVar(&dryRun, "dry-run", false, "List reverse proxy instances that would be disconnected without changing state")
disconnectAllCmd.Flags().BoolVar(&force, "force", false, "Skip the confirmation prompt and apply the repair")
proxyCmd.AddCommand(disconnectAllCmd)
return proxyCmd
}
func runDisconnectAll(ctx context.Context, s store.Store, out io.Writer, in io.Reader, dryRun, force bool) error {
proxies, err := s.GetAllProxies(ctx)
if err != nil {
return fmt.Errorf("list proxies: %w", err)
}
if len(proxies) == 0 {
_, _ = fmt.Fprintln(out, "No reverse proxy instances found.")
return nil
}
toDisconnect := 0
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "ID\tCLUSTER\tIP\tACCOUNT\tSTATUS\tLAST SEEN")
_, _ = fmt.Fprintln(w, "--\t-------\t--\t-------\t------\t---------")
for _, p := range proxies {
if p.Status != rpproxy.StatusDisconnected {
toDisconnect++
}
account := "-"
if p.AccountID != nil {
account = *p.AccountID
}
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
p.ID,
p.ClusterAddress,
p.IPAddress,
account,
p.Status,
p.LastSeen.Format("2006-01-02 15:04:05"),
)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("write proxy list: %w", err)
}
if dryRun {
_, _ = fmt.Fprintf(out, "\nDry run: would force-mark %d of %d reverse proxy instance(s) as disconnected.\n", toDisconnect, len(proxies))
return nil
}
if !force {
confirmed, err := confirmDisconnectAll(out, in)
if err != nil {
return err
}
if !confirmed {
_, _ = fmt.Fprintln(out, "Aborted. No reverse proxy instances were changed.")
return nil
}
}
disconnected, err := s.DisconnectAllProxies(ctx)
if err != nil {
return fmt.Errorf("disconnect proxies: %w", err)
}
_, _ = fmt.Fprintf(out, "\nForce-marked %d of %d reverse proxy instance(s) as disconnected.\n", disconnected, len(proxies))
return nil
}
func confirmDisconnectAll(out io.Writer, in io.Reader) (bool, error) {
if in == nil {
in = strings.NewReader("")
}
_, _ = fmt.Fprintln(out, "\nWARNING: This command changes stored reverse proxy state for every non-disconnected instance.")
_, _ = fmt.Fprintln(out, "Run it during a maintenance window; affected live proxies may stay hidden until "+
"their next heartbeat or reconnect/re-register.")
_, _ = fmt.Fprintf(out, "Type %q to continue: ", disconnectAllConfirmation)
scanner := bufio.NewScanner(in)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("read confirmation: %w", err)
}
return false, nil
}
return strings.EqualFold(strings.TrimSpace(scanner.Text()), disconnectAllConfirmation), nil
}

View File

@@ -1,180 +0,0 @@
package proxycmd
import (
"bytes"
"context"
"strings"
"testing"
"time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
)
func newTestStore(t *testing.T) store.Store {
t.Helper()
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
return s
}
func seedProxies(t *testing.T, ctx context.Context, s store.Store) {
t.Helper()
accountID := "account-1"
alreadyDisconnectedAt := time.Now().Add(-time.Hour)
seed := []*rpproxy.Proxy{
{
ID: "proxy-1",
SessionID: "session-1",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.1",
LastSeen: time.Now(),
Status: rpproxy.StatusConnected,
},
{
ID: "proxy-2",
SessionID: "session-2",
ClusterAddress: "cluster-b.example.com",
IPAddress: "10.0.0.2",
AccountID: &accountID,
LastSeen: time.Now(),
Status: rpproxy.StatusConnected,
},
{
ID: "proxy-3",
SessionID: "session-3",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.3",
LastSeen: time.Now().Add(-time.Hour),
Status: rpproxy.StatusDisconnected,
DisconnectedAt: &alreadyDisconnectedAt,
},
}
for _, p := range seed {
require.NoError(t, s.SaveProxy(ctx, p))
}
}
func proxiesByID(t *testing.T, ctx context.Context, s store.Store) map[string]*rpproxy.Proxy {
t.Helper()
proxies, err := s.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, proxies, 3)
byID := make(map[string]*rpproxy.Proxy, len(proxies))
for _, p := range proxies {
byID[p.ID] = p
}
return byID
}
func TestRunDisconnectAllWithConfirmation(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(disconnectAllConfirmation+"\n"), false, false))
output := out.String()
require.Contains(t, output, "proxy-1")
require.Contains(t, output, "proxy-2")
require.Contains(t, output, "proxy-3")
require.Contains(t, output, "cluster-a.example.com")
require.Contains(t, output, "account-1")
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
for _, p := range proxiesByID(t, ctx, s) {
require.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", p.ID)
require.NotNil(t, p.DisconnectedAt, "proxy %s should have a disconnected timestamp", p.ID)
}
}
func TestRunDisconnectAllForceSkipsConfirmation(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, true))
output := out.String()
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
}
func TestRunDisconnectAllAbortLeavesProxiesUnchanged(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader("no\n"), false, false))
output := out.String()
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Aborted. No reverse proxy instances were changed.")
byID := proxiesByID(t, ctx, s)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
}
func TestRunDisconnectAllDryRunLeavesProxiesUnchanged(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), true, false))
output := out.String()
require.Contains(t, output, "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
byID := proxiesByID(t, ctx, s)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
}
func TestNewCommandsDisconnectAllDryRun(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
opened := false
cmd := NewCommands(func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
opened = true
return fn(cmd.Context(), s)
})
var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetErr(&out)
cmd.SetIn(strings.NewReader(""))
cmd.SetArgs([]string{"disconnect-all", "--dry-run"})
require.NoError(t, cmd.ExecuteContext(ctx))
require.True(t, opened)
require.Contains(t, out.String(), "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
}
func TestRunDisconnectAllEmpty(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, false))
require.Contains(t, out.String(), "No reverse proxy instances found.")
}

View File

@@ -83,8 +83,7 @@ func init() {
rootCmd.AddCommand(migrationCmd)
ac := newAdminCommands()
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(ac)
rootCmd.AddCommand(newLegacyTokenCommand())
tc := newTokenCommands()
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(tc)
}

55
management/cmd/token.go Normal file
View File

@@ -0,0 +1,55 @@
package cmd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var tokenDatadir string
// newTokenCommands creates the token command tree with management-specific store opener.
func newTokenCommands() *cobra.Command {
cmd := tokencmd.NewCommands(withTokenStore)
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
return cmd
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -608,11 +608,11 @@ func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
}
conn.cancel()
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
}
conn.cancel()
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
}

View File

@@ -21,11 +21,8 @@ import (
)
const (
StaticClientDashboard = "netbird-dashboard"
StaticClientCLI = "netbird-cli"
DefaultTOTPAuthenticatorID = "default-totp"
LocalConnectorID = dex.LocalConnectorID
staticClientDashboard = "netbird-dashboard"
staticClientCLI = "netbird-cli"
defaultCLIRedirectURL1 = "http://localhost:53000/"
defaultCLIRedirectURL2 = "http://localhost:54000/"
defaultScopes = "openid profile email groups"
@@ -188,14 +185,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
EnablePasswordDB: true,
StaticClients: []storage.Client{
{
ID: StaticClientDashboard,
ID: staticClientDashboard,
Name: "NetBird Dashboard",
Public: true,
RedirectURIs: redirectURIs,
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
},
{
ID: StaticClientCLI,
ID: staticClientCLI,
Name: "NetBird CLI",
Public: true,
RedirectURIs: redirectURIs,
@@ -257,13 +254,13 @@ func sanitizePostLogoutRedirectURIs(uris []string) []string {
func configureMFA(cfg *dex.YAMLConfig, sessionMaxLifetime, sessionIdleTimeout string, rememberMe bool, sessionCookieEncryptionKey string) error {
cfg.MFA.Authenticators = []dex.MFAAuthenticator{{
ID: DefaultTOTPAuthenticatorID,
ID: "default-totp",
// Has to be caps otherwise it will fail
Type: "TOTP",
Config: map[string]interface{}{
"issuer": "NetBird",
},
ConnectorTypes: []string{LocalConnectorID},
ConnectorTypes: []string{"local"},
}}
if sessionMaxLifetime == "" {
@@ -739,7 +736,7 @@ func (m *EmbeddedIdPManager) GetDefaultScopes() string {
// GetCLIClientID returns the client ID for CLI authentication.
func (m *EmbeddedIdPManager) GetCLIClientID() string {
return StaticClientCLI
return staticClientCLI
}
// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client.
@@ -778,7 +775,7 @@ func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
func (m *EmbeddedIdPManager) GetClientIDs() []string {
return []string{StaticClientDashboard, StaticClientCLI}
return []string{staticClientDashboard, staticClientCLI}
}
// GetUserIDClaim returns the JWT claim name used for user identification.
@@ -795,11 +792,11 @@ func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
func (m *EmbeddedIdPManager) SetMFAEnabled(ctx context.Context, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{DefaultTOTPAuthenticatorID}
mfaChain = []string{"default-totp"}
}
if err := m.provider.SetClientsMFAChain(ctx, []string{
StaticClientCLI,
StaticClientDashboard,
staticClientCLI,
staticClientDashboard,
}, mfaChain); err != nil {
return fmt.Errorf("failed to set MFA enabled=%v: %w", enabled, err)
}

View File

@@ -331,7 +331,7 @@ func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *tes
var cliRedirectURIs []string
for _, client := range yamlConfig.StaticClients {
if client.ID == StaticClientCLI {
if client.ID == staticClientCLI {
cliRedirectURIs = client.RedirectURIs
break
}

View File

@@ -6088,37 +6088,6 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
return nil
}
// GetAllProxies returns all reverse proxy instance rows.
func (s *SqlStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
var proxies []*proxy.Proxy
result := s.db.Order("cluster_address, id").Find(&proxies)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get proxies: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get proxies")
}
return proxies, nil
}
// DisconnectAllProxies force-marks every proxy that is not already disconnected
// as disconnected, regardless of session ID. Unlike DisconnectProxy it is not
// session-guarded: it is an administrative repair helper, not part of the
// connection lifecycle. last_seen is left untouched so the stale-proxy reaper
// keeps working off the real last heartbeat. Returns the number of proxies updated.
func (s *SqlStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
result := s.db.
Model(&proxy.Proxy{}).
Where("status != ?", proxy.StatusDisconnected).
Updates(map[string]any{
"status": proxy.StatusDisconnected,
"disconnected_at": time.Now(),
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect all proxies: %v", result.Error)
return 0, status.Errorf(status.Internal, "failed to disconnect all proxies")
}
return result.RowsAffected, nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
now := time.Now()
@@ -6126,11 +6095,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
Updates(map[string]any{
"last_seen": now,
"status": proxy.StatusConnected,
"disconnected_at": nil,
})
Update("last_seen", now)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)

View File

@@ -1,156 +0,0 @@
package store
import (
"context"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// TestSqlStore_DisconnectAllProxies guards the administrative
// force-disconnect helper:
//
// 1. Every proxy that is not already disconnected is marked
// disconnected regardless of its session ID (unlike
// DisconnectProxy, which is session-guarded).
// 2. Rows that are already disconnected are left untouched, so their
// original disconnected_at is preserved and the returned count
// reflects only the rows that actually changed.
// 3. last_seen is not modified — the stale-proxy reaper keeps working
// off the real last heartbeat.
func TestSqlStore_DisconnectAllProxies(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
lastSeenFresh := time.Now().Add(-30 * time.Second)
lastSeenStale := time.Now().Add(-30 * time.Minute)
oldDisconnectedAt := time.Now().Add(-time.Hour)
accountID := "acct-disconnect"
proxies := []*rpproxy.Proxy{
{
ID: "p-connected-fresh",
SessionID: "sess-1",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.1",
LastSeen: lastSeenFresh,
Status: rpproxy.StatusConnected,
},
{
ID: "p-connected-stale",
SessionID: "sess-2",
ClusterAddress: "cluster-b.example.com",
IPAddress: "10.0.0.2",
AccountID: &accountID,
LastSeen: lastSeenStale,
Status: rpproxy.StatusConnected,
},
{
ID: "p-already-disconnected",
SessionID: "sess-3",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.3",
LastSeen: lastSeenStale,
Status: rpproxy.StatusDisconnected,
DisconnectedAt: &oldDisconnectedAt,
},
}
for _, p := range proxies {
require.NoError(t, store.SaveProxy(ctx, p))
}
all, err := store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 3)
disconnected, err := store.DisconnectAllProxies(ctx)
require.NoError(t, err)
assert.Equal(t, int64(2), disconnected)
all, err = store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 3)
byID := make(map[string]*rpproxy.Proxy, len(all))
for _, p := range all {
byID[p.ID] = p
}
for id, p := range byID {
assert.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", id)
require.NotNil(t, p.DisconnectedAt, "proxy %s should have disconnected_at set", id)
}
// force-marked rows carry a fresh disconnected_at; the untouched row keeps its original one
assert.WithinDuration(t, time.Now(), *byID["p-connected-fresh"].DisconnectedAt, 10*time.Second)
assert.WithinDuration(t, time.Now(), *byID["p-connected-stale"].DisconnectedAt, 10*time.Second)
assert.WithinDuration(t, oldDisconnectedAt, *byID["p-already-disconnected"].DisconnectedAt, time.Second)
// last_seen is preserved so the stale reaper schedule is unaffected
assert.WithinDuration(t, lastSeenFresh, byID["p-connected-fresh"].LastSeen, time.Second)
assert.WithinDuration(t, lastSeenStale, byID["p-connected-stale"].LastSeen, time.Second)
// idempotent: a second run has nothing left to update
disconnected, err = store.DisconnectAllProxies(ctx)
require.NoError(t, err)
assert.Equal(t, int64(0), disconnected)
})
}
func TestSqlStore_UpdateProxyHeartbeatRestoresDisconnectedCurrentSession(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
proxy := &rpproxy.Proxy{
ID: "p-heartbeat",
SessionID: "sess-heartbeat",
ClusterAddress: "cluster-heartbeat.example.com",
IPAddress: "10.0.0.10",
LastSeen: time.Now().Add(-30 * time.Second),
Status: rpproxy.StatusConnected,
}
require.NoError(t, store.SaveProxy(ctx, proxy))
disconnected, err := store.DisconnectAllProxies(ctx)
require.NoError(t, err)
require.Equal(t, int64(1), disconnected)
require.NoError(t, store.UpdateProxyHeartbeat(ctx, &rpproxy.Proxy{ID: proxy.ID, SessionID: proxy.SessionID}))
all, err := store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 1)
assert.Equal(t, rpproxy.StatusConnected, all[0].Status)
assert.Nil(t, all[0].DisconnectedAt)
assert.WithinDuration(t, time.Now(), all[0].LastSeen, 10*time.Second)
addresses, err := store.GetActiveProxyClusterAddresses(ctx)
require.NoError(t, err)
assert.Contains(t, addresses, proxy.ClusterAddress)
})
}
func TestSqlStore_GetAllProxies_Empty(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
all, err := store.GetAllProxies(context.Background())
require.NoError(t, err)
assert.Empty(t, all)
})
}

View File

@@ -323,8 +323,6 @@ type Store interface {
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error)
DisconnectAllProxies(ctx context.Context) (int64, error)
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)

View File

@@ -745,21 +745,6 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// DisconnectAllProxies mocks base method.
func (m *MockStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectAllProxies", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DisconnectAllProxies indicates an expected call of DisconnectAllProxies.
func (mr *MockStoreMockRecorder) DisconnectAllProxies(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectAllProxies", reflect.TypeOf((*MockStore)(nil).DisconnectAllProxies), ctx)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
@@ -1404,21 +1389,6 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
}
// GetAllProxies mocks base method.
func (m *MockStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllProxies", ctx)
ret0, _ := ret[0].([]*proxy.Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllProxies indicates an expected call of GetAllProxies.
func (mr *MockStoreMockRecorder) GetAllProxies(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxies", reflect.TypeOf((*MockStore)(nil).GetAllProxies), ctx)
}
// GetAllProxyAccessTokens mocks base method.
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
@@ -1551,21 +1521,6 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetExpiredEphemeralServices mocks base method.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
@@ -1611,21 +1566,6 @@ func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, gr
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetGroupsByIDs mocks base method.
func (m *MockStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types2.Group, error) {
m.ctrl.T.Helper()
@@ -1910,21 +1850,6 @@ func (mr *MockStoreMockRecorder) GetPeerIDByKey(ctx, lockStrength, key interface
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDByKey", reflect.TypeOf((*MockStore)(nil).GetPeerIDByKey), ctx, lockStrength, key)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetPeerIdByLabel mocks base method.
func (m *MockStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID, hostname string) (string, error) {
m.ctrl.T.Helper()
@@ -2000,6 +1925,51 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -1849,17 +1849,12 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
const minPasswordLength = 8
// validatePassword checks password strength requirements.
func validatePassword(password string) error {
return ValidatePassword(password)
}
// ValidatePassword checks password strength requirements:
// validatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func ValidatePassword(password string) error {
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}