Compare commits

..

7 Commits

25 changed files with 2045 additions and 226 deletions

151
combined/cmd/admin.go Normal file
View File

@@ -0,0 +1,151 @@
package cmd
import (
"context"
"fmt"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newAdminCommands creates the admin command tree with combined-specific resource openers.
func newAdminCommands() *cobra.Command {
return admincmd.NewCommands(admincmd.Openers{
Resources: withAdminResources,
Store: withAdminStoreOnly,
IDP: withAdminIDPOnly,
})
}
func newLegacyTokenCommand() *cobra.Command {
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
cmd.Deprecated = "use 'admin token' instead"
return cmd
}
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
mgmtConfig, err := adminManagementConfig(cfg)
if err != nil {
return err
}
managementStore, err := openAdminStore(ctx, cfg)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
eventStore, esErr := openAdminEventStore(ctx, cfg, mgmtConfig)
if esErr != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
}
if eventStore != nil {
defer func() {
if err := eventStore.Close(ctx); err != nil {
log.Debugf("close activity event store: %v", err)
}
}()
}
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
})
}
// withAdminStoreOnly opens only the management store for admin subcommands that do not
// need embedded IdP storage.
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
managementStore, err := openAdminStore(ctx, cfg)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
return fn(ctx, managementStore)
})
}
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
mgmtConfig, err := adminManagementConfig(cfg)
if err != nil {
return err
}
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
return fn(ctx, idpStorage, idpStorageFile)
})
}
func withAdminConfig(cmd *cobra.Command, fn func(ctx context.Context, cfg *CombinedConfig) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
cfg.ApplyAdminDefaults()
applyServerStoreEnv(cfg.Server.Store)
return fn(ctx, cfg)
}
func adminManagementConfig(cfg *CombinedConfig) (*nbconfig.Config, error) {
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return nil, fmt.Errorf("create management config: %w", err)
}
return mgmtConfig, nil
}
func openAdminStore(ctx context.Context, cfg *CombinedConfig) (store.Store, error) {
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
if err != nil {
return nil, fmt.Errorf("create store: %w", err)
}
return managementStore, nil
}
func openAdminEventStore(ctx context.Context, cfg *CombinedConfig, config *nbconfig.Config) (activity.Store, error) {
if config.DataStoreEncryptionKey == "" {
return nil, fmt.Errorf("data store encryption key is not configured")
}
if err := applyActivityStoreEnv(cfg.Server.ActivityStore); err != nil {
return nil, fmt.Errorf("configure activity event store: %w", err)
}
eventStore, err := activitystore.NewSqlStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
if err != nil {
return nil, fmt.Errorf("open activity event store: %w", err)
}
if eventStore == nil {
return nil, fmt.Errorf("open activity event store: returned nil store")
}
return eventStore, nil
}

View File

@@ -0,0 +1,47 @@
package cmd
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestApplyAdminDefaultsCopiesServerStoreWithoutExposedAddress(t *testing.T) {
cfg := DefaultConfig()
cfg.Server.ExposedAddress = ""
cfg.Server.DataDir = "/srv/netbird"
cfg.Server.Store = StoreConfig{
Engine: "postgres",
DSN: "postgres://user:pass@example.com/netbird",
}
cfg.ApplyAdminDefaults()
require.Equal(t, "/srv/netbird", cfg.Management.DataDir)
require.Equal(t, "postgres", cfg.Management.Store.Engine)
require.Equal(t, cfg.Server.Store.DSN, cfg.Management.Store.DSN)
}
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
eventStore, err := openAdminEventStore(context.Background(), &CombinedConfig{}, &nbconfig.Config{})
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key")
require.Nil(t, eventStore)
}
func TestApplyServerStoreEnv(t *testing.T) {
t.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", "")
t.Setenv("NB_STORE_ENGINE_MYSQL_DSN", "")
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", "")
applyServerStoreEnv(StoreConfig{Engine: "postgres", DSN: "postgres-dsn", File: "store.db"})
require.Equal(t, "postgres-dsn", os.Getenv("NB_STORE_ENGINE_POSTGRES_DSN"))
require.Equal(t, "store.db", os.Getenv("NB_STORE_ENGINE_SQLITE_FILE"))
applyServerStoreEnv(StoreConfig{Engine: "mysql", DSN: "mysql-dsn"})
require.Equal(t, "mysql-dsn", os.Getenv("NB_STORE_ENGINE_MYSQL_DSN"))
}

View File

@@ -6,8 +6,7 @@ import (
"net"
"net/netip"
"os"
"path"
"path/filepath"
filePath "path/filepath"
"strings"
"time"
@@ -299,6 +298,19 @@ func (c *CombinedConfig) ApplySimplifiedDefaults() {
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// ApplyAdminDefaults applies the management settings needed by admin commands even
// when the full server config is invalid and ApplySimplifiedDefaults cannot run.
func (c *CombinedConfig) ApplyAdminDefaults() {
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" || c.Server.Store.File != "" || c.Server.Store.DSN != "" {
c.Management.Store = c.Server.Store
}
}
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
@@ -576,11 +588,11 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
}
} else {
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
authStorageFile = filePath.Join(mgmt.DataDir, "idp.db")
if c.Server.AuthStore.File != "" {
authStorageFile = c.Server.AuthStore.File
if !filepath.IsAbs(authStorageFile) {
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
if !filePath.IsAbs(authStorageFile) {
authStorageFile = filePath.Join(mgmt.DataDir, authStorageFile)
}
}
}
@@ -727,7 +739,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
cfg.EmbeddedIdP.Storage.Config.File = filePath.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer

View File

@@ -64,7 +64,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
rootCmd.AddCommand(newTokenCommands())
rootCmd.AddCommand(newAdminCommands())
rootCmd.AddCommand(newLegacyTokenCommand())
}
func RootCmd() *cobra.Command {
@@ -122,6 +123,37 @@ func execute(cmd *cobra.Command, _ []string) error {
}
// initializeConfig loads and validates the configuration, then initializes logging.
func applyServerStoreEnv(storeConfig StoreConfig) {
if dsn := storeConfig.DSN; dsn != "" {
switch strings.ToLower(storeConfig.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := storeConfig.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
}
func applyActivityStoreEnv(storeConfig StoreConfig) error {
if engine := storeConfig.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && storeConfig.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := storeConfig.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := storeConfig.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
}
return nil
}
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
@@ -137,30 +169,10 @@ func initializeConfig() error {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := config.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
applyServerStoreEnv(config.Server.Store)
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
if file := config.Server.ActivityStore.File; file != "" {
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
if err := applyActivityStoreEnv(config.Server.ActivityStore); err != nil {
return err
}
log.Infof("Starting combined NetBird server")

View File

@@ -1,63 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newTokenCommands creates the token command tree with combined-specific store opener.
func newTokenCommands() *cobra.Command {
return tokencmd.NewCommands(withTokenStore)
}
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)
s, err := store.NewStore(ctx, engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -41,7 +41,7 @@ type Config struct {
GRPCAddr string
}
const localConnectorID = "local"
const LocalConnectorID = "local"
// Provider wraps a Dex server
type Provider struct {
@@ -495,18 +495,60 @@ func (p *Provider) Storage() storage.Storage {
return p.storage
}
// SetClientsMFAChain updates the MFAChain field on OAuth2 clients in Dex storage.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func SetClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, mfaChain []string) error {
previousChains := make(map[string][]string, len(clientIDs))
for _, clientID := range clientIDs {
client, err := st.GetClient(ctx, clientID)
if err != nil {
return fmt.Errorf("failed to get client %s before MFA chain update: %w", clientID, err)
}
previousChains[clientID] = cloneMFAChain(client.MFAChain)
}
updatedClientIDs := make([]string, 0, len(clientIDs))
for _, clientID := range clientIDs {
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = cloneMFAChain(mfaChain)
return old, nil
}); err != nil {
if rollbackErr := rollbackClientsMFAChain(ctx, st, updatedClientIDs, previousChains); rollbackErr != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w (also failed to roll back previous MFA chains: %v)", clientID, err, rollbackErr)
}
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
updatedClientIDs = append(updatedClientIDs, clientID)
}
return nil
}
func rollbackClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, previousChains map[string][]string) error {
var rollbackErrs []error
for i := len(clientIDs) - 1; i >= 0; i-- {
clientID := clientIDs[i]
previousChain := cloneMFAChain(previousChains[clientID])
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = previousChain
return old, nil
}); err != nil {
rollbackErrs = append(rollbackErrs, fmt.Errorf("client %s: %w", clientID, err))
}
}
return errors.Join(rollbackErrs...)
}
func cloneMFAChain(chain []string) []string {
if chain == nil {
return nil
}
return append([]string(nil), chain...)
}
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
for _, clientID := range clientIDs {
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
}
return nil
return SetClientsMFAChain(ctx, p.storage, clientIDs, mfaChain)
}
// Handler returns the Dex server as an http.Handler for embedding in another server.
@@ -546,7 +588,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
// This matches the format Dex uses in JWT tokens
encodedID := EncodeDexUserID(userID, localConnectorID)
encodedID := EncodeDexUserID(userID, LocalConnectorID)
return encodedID, nil
}
@@ -625,7 +667,7 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
// local password connector.
func IsLocalUserID(encodedID string) bool {
_, connectorID, err := DecodeDexUserID(encodedID)
return err == nil && connectorID == localConnectorID
return err == nil && connectorID == LocalConnectorID
}
// GetUser returns a user by email

View File

@@ -3,6 +3,8 @@ package dex
import (
"context"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"net/http/httptest"
@@ -11,11 +13,44 @@ import (
"testing"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
sqllib "github.com/dexidp/dex/storage/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type updateFailingStorage struct {
storage.Storage
failClientID string
}
func (s *updateFailingStorage) UpdateClient(ctx context.Context, id string, updater func(storage.Client) (storage.Client, error)) error {
if id == s.failClientID {
return errors.New("forced update failure")
}
return s.Storage.UpdateClient(ctx, id, updater)
}
func TestSetClientsMFAChainRollsBackUpdatedClients(t *testing.T) {
ctx := context.Background()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-1", MFAChain: []string{"old-1"}}))
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-2", MFAChain: []string{"old-2"}}))
err := SetClientsMFAChain(ctx, &updateFailingStorage{Storage: st, failClientID: "client-2"}, []string{"client-1", "client-2"}, []string{"new"})
require.Error(t, err)
require.Contains(t, err.Error(), "failed to update MFA chain on client client-2")
client1, err := st.GetClient(ctx, "client-1")
require.NoError(t, err)
require.Equal(t, []string{"old-1"}, client1.MFAChain)
client2, err := st.GetClient(ctx, "client-2")
require.NoError(t, err)
require.Equal(t, []string{"old-2"}, client2.MFAChain)
}
func TestUserCreationFlow(t *testing.T) {
ctx := context.Background()

View File

@@ -556,7 +556,7 @@ start_services_and_show_instructions() {
echo "Creating proxy access token..."
# Use docker exec with bash to run the token command directly
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T netbird-server \
/go/bin/netbird-server token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
/go/bin/netbird-server admin token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
if [[ -z "$PROXY_TOKEN" ]]; then
echo "ERROR: Failed to create proxy token. Check netbird-server logs." > /dev/stderr

177
management/cmd/admin.go Normal file
View File

@@ -0,0 +1,177 @@
package cmd
import (
"context"
"fmt"
"path"
"path/filepath"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var adminDatadir string
// newAdminCommands creates the admin command tree with management-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(admincmd.Openers{
Resources: withAdminResources,
Store: withAdminStoreOnly,
IDP: withAdminIDPOnly,
})
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
return cmd
}
func newLegacyTokenCommand() *cobra.Command {
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
cmd.Deprecated = "use 'admin token' instead"
cmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
return cmd
}
// withAdminResources initializes logging, loads config, opens the management store
// and embedded IdP storage, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
managementStore, err := openAdminStore(ctx, config, datadir)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
eventStore, esErr := openAdminEventStore(ctx, config, datadir)
if esErr != nil {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
}
if eventStore != nil {
defer func() {
if err := eventStore.Close(ctx); err != nil {
log.Debugf("close activity event store: %v", err)
}
}()
}
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
})
}
// withAdminStoreOnly opens only the management store for admin subcommands that do not
// need embedded IdP storage.
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminConfig(cmd, false, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
managementStore, err := openAdminStore(ctx, config, datadir)
if err != nil {
return err
}
defer admincmd.CloseStore(ctx, managementStore)
return fn(ctx, managementStore)
})
}
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, _ string) error {
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
if err != nil {
return err
}
defer admincmd.CloseIDPStorage(idpStorage)
return fn(ctx, idpStorage, idpStorageFile)
})
}
func withAdminConfig(cmd *cobra.Command, applyIDPDefaults bool, fn func(ctx context.Context, config *nbconfig.Config, datadir string) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, datadir, err := loadAdminMgmtConfig(ctx, applyIDPDefaults)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
return fn(ctx, config, datadir)
}
func loadAdminMgmtConfig(ctx context.Context, applyIDPDefaults bool) (*nbconfig.Config, string, error) {
config := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(nbconfig.MgmtConfigPath, config); err != nil {
return nil, "", err
}
if applyIDPDefaults {
if err := ApplyEmbeddedIdPConfig(ctx, config); err != nil {
return nil, "", err
}
}
datadir := config.Datadir
applyAdminDatadirOverride(config, &datadir)
return config, datadir, nil
}
func applyAdminDatadirOverride(config *nbconfig.Config, datadir *string) {
if adminDatadir == "" {
return
}
oldDatadir := *datadir
*datadir = adminDatadir
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" && isDefaultIDPStorageFile(config.EmbeddedIdP.Storage.Config.File, oldDatadir) {
config.EmbeddedIdP.Storage.Config.File = filepath.Join(*datadir, "idp.db")
}
}
func isDefaultIDPStorageFile(file, datadir string) bool {
if file == "" {
return true
}
defaultFile := filepath.Join(datadir, "idp.db")
legacyDefaultFile := path.Join(datadir, "idp.db")
legacySlashDefaultFile := path.Join(filepath.ToSlash(datadir), "idp.db")
return filepath.Clean(file) == filepath.Clean(defaultFile) ||
file == legacyDefaultFile ||
filepath.ToSlash(file) == legacySlashDefaultFile
}
func openAdminStore(ctx context.Context, config *nbconfig.Config, datadir string) (store.Store, error) {
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return nil, fmt.Errorf("create store: %w", err)
}
return managementStore, nil
}
func openAdminEventStore(ctx context.Context, config *nbconfig.Config, datadir string) (activity.Store, error) {
if config.DataStoreEncryptionKey == "" {
return nil, fmt.Errorf("data store encryption key is not configured")
}
eventStore, err := activitystore.NewSqlStore(ctx, datadir, config.DataStoreEncryptionKey)
if err != nil {
return nil, fmt.Errorf("open activity event store: %w", err)
}
if eventStore == nil {
return nil, fmt.Errorf("open activity event store: returned nil store")
}
return eventStore, nil
}

View File

@@ -0,0 +1,577 @@
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
// Both the management and combined binaries use these commands, each providing
// their own opener to handle config loading and storage initialization.
package admincmd
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"strings"
"time"
"github.com/dexidp/dex/storage"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
"github.com/netbirdio/netbird/formatter/hook"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/cmd/proxy"
"github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// Resources contains the storages required by the admin commands.
type Resources struct {
Store store.Store
IDPStorage storage.Storage
IDPStorageFile string
EventStore activity.Store
}
// Opener initializes command resources from the command context and calls fn.
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
// StoreOpener initializes only the management store from the command context and calls fn.
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
// IDPOpener initializes only the embedded IdP storage from the command context and calls fn.
type IDPOpener func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error
// Openers contains the resource openers needed by the admin command tree.
type Openers struct {
Resources Opener
Store StoreOpener
IDP IDPOpener
}
type userSelector struct {
email string
userID string
}
func (s userSelector) normalized() userSelector {
return userSelector{
email: strings.TrimSpace(s.email),
userID: strings.TrimSpace(s.userID),
}
}
func (s userSelector) validate() error {
s = s.normalized()
if (s.email == "") == (s.userID == "") {
return fmt.Errorf("provide exactly one of --email or --user-id")
}
return nil
}
// NewCommands creates the admin command tree with the given resource openers.
func NewCommands(openers Openers) *cobra.Command {
adminCmd := &cobra.Command{
Use: "admin",
Short: "Self-hosted administrator helpers",
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
}
userCmd := &cobra.Command{
Use: "user",
Short: "Manage local embedded IdP users",
}
var passwordSelector userSelector
var password string
var passwordFile string
passwordCmd := &cobra.Command{
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
Aliases: []string{"set-password"},
Short: "Change a local user's password",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
if err := passwordSelector.validate(); err != nil {
return err
}
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
if err != nil {
return err
}
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
return runChangePassword(ctx, idpStorage, cmd.OutOrStdout(), passwordSelector, newPassword, storageFile)
})
},
}
addUserSelectorFlags(passwordCmd, &passwordSelector)
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
var resetSelector userSelector
resetMFACmd := &cobra.Command{
Use: "reset-mfa (--email email | --user-id id)",
Short: "Reset a local user's MFA enrollment",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
if err := resetSelector.validate(); err != nil {
return err
}
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
return runResetMFA(ctx, idpStorage, cmd.OutOrStdout(), resetSelector, storageFile)
})
},
}
addUserSelectorFlags(resetMFACmd, &resetSelector)
userCmd.AddCommand(passwordCmd, resetMFACmd)
mfaCmd := &cobra.Command{
Use: "mfa",
Short: "Manage local MFA for embedded IdP users",
}
enableCmd := &cobra.Command{
Use: "enable",
Short: "Enable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
})
},
}
disableCmd := &cobra.Command{
Use: "disable",
Short: "Disable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
})
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show local MFA status",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
})
},
}
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
adminCmd.AddCommand(userCmd, mfaCmd)
if openers.Store != nil {
adminCmd.AddCommand(tokencmd.NewCommands(tokencmd.StoreOpener(openers.Store)))
adminCmd.AddCommand(proxycmd.NewCommands(proxycmd.StoreOpener(openers.Store)))
}
return adminCmd
}
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
if cfg == nil || !cfg.Enabled {
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
}
yamlConfig, err := cfg.ToYAMLConfig()
if err != nil {
return nil, fmt.Errorf("build embedded IdP config: %w", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
st, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
}
return st, nil
}
// CloseStore closes the management store and logs cleanup errors at debug level.
func CloseStore(ctx context.Context, s store.Store) {
if s == nil {
return
}
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}
// OpenIDPStorage opens embedded IdP storage and returns its sqlite file path when applicable.
func OpenIDPStorage(config *nbconfig.Config) (storage.Storage, string, error) {
if config == nil {
return nil, "", fmt.Errorf("management config is required")
}
idpStorage, err := OpenEmbeddedIDPStorage(config.EmbeddedIdP)
if err != nil {
return nil, "", err
}
return idpStorage, embeddedIDPStorageFile(config), nil
}
func embeddedIDPStorageFile(config *nbconfig.Config) string {
if config.EmbeddedIdP == nil || config.EmbeddedIdP.Storage.Type != "sqlite3" {
return ""
}
return config.EmbeddedIdP.Storage.Config.File
}
// CloseIDPStorage closes embedded IdP storage and logs cleanup errors at debug level.
func CloseIDPStorage(s storage.Storage) {
if s == nil {
return
}
if err := s.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
}
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
if password != "" && passwordFile != "" {
return "", fmt.Errorf("provide only one of --password or --password-file")
}
if passwordFile == "" {
return password, nil
}
var data []byte
var err error
if passwordFile == "-" {
data, err = io.ReadAll(cmd.InOrStdin())
} else {
data, err = os.ReadFile(passwordFile)
}
if err != nil {
return "", fmt.Errorf("read password: %w", err)
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string, idpStorageFile string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
if password == "" {
return fmt.Errorf("password is required")
}
if err := server.ValidatePassword(password); err != nil {
return fmt.Errorf("invalid password: %w", err)
}
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = hash
return old, nil
}); err != nil {
return fmt.Errorf("update password for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
return nil
}
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, idpStorageFile string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
if err != nil {
return err
}
reset := false
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
old.MFASecrets = map[string]*storage.MFASecret{}
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
return old, nil
})
if errors.Is(err, storage.ErrNotFound) {
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
return nil
}
if err != nil {
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
if reset {
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
} else {
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
}
return nil
}
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accountID, settings, err := getSingleAccountSettings(ctx, resources.Store)
if err != nil {
return err
}
oldEnabled := settings.LocalMfaEnabled
newSettings := settings.Copy()
newSettings.LocalMfaEnabled = enabled
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
return err
}
if err := resources.Store.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
if rollbackErr := setIDPClientsMFA(ctx, resources.IDPStorage, oldEnabled); rollbackErr != nil {
return fmt.Errorf("save local MFA account setting: %w (also failed to roll back embedded IdP MFA state: %v)", err, rollbackErr)
}
return fmt.Errorf("save local MFA account setting: %w", err)
}
if err := storeMFAActivity(ctx, resources.EventStore, accountID, enabled); err != nil {
_, _ = fmt.Fprintf(w, "Warning: failed to record audit event: %v\n", err)
}
state := "disabled"
if enabled {
state = "enabled"
}
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
return nil
}
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
_, settings, err := getSingleAccountSettings(ctx, resources.Store)
if err != nil {
return err
}
accountStatus := "disabled"
if settings.LocalMfaEnabled {
accountStatus = "enabled"
}
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
if err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
return nil
}
func getSingleAccountSettings(ctx context.Context, s store.Store) (string, *types.Settings, error) {
count, err := s.GetAccountsCounter(ctx)
if err != nil {
return "", nil, fmt.Errorf("count accounts: %w", err)
}
if count != 1 {
return "", nil, fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", count)
}
accountID, err := s.GetAnyAccountID(ctx)
if err != nil {
return "", nil, fmt.Errorf("get account ID: %w", err)
}
settings, err := s.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return "", nil, fmt.Errorf("get account settings: %w", err)
}
if settings == nil {
settings = &types.Settings{}
}
return accountID, settings, nil
}
func storeMFAActivity(ctx context.Context, eventStore activity.Store, accountID string, enabled bool) error {
if eventStore == nil {
return nil
}
event := activity.AccountLocalMfaDisabled
if enabled {
event = activity.AccountLocalMfaEnabled
}
_, err := eventStore.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: event,
InitiatorID: string(hook.SystemSource),
TargetID: accountID,
AccountID: accountID,
})
if err != nil {
return fmt.Errorf("save local MFA audit event: %w", err)
}
return nil
}
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector, idpStorageFile string) (storage.Password, error) {
selector = selector.normalized()
if err := selector.validate(); err != nil {
return storage.Password{}, err
}
if selector.email != "" {
user, err := idpStorage.GetPassword(ctx, selector.email)
if errors.Is(err, storage.ErrNotFound) {
if empty, listErr := localUsersEmpty(ctx, idpStorage); listErr != nil {
return storage.Password{}, listErr
} else if empty {
return storage.Password{}, noLocalUsersError(idpStorageFile)
}
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
}
if err != nil {
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
}
return user, nil
}
rawUserID := selector.userID
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
rawUserID = decodedUserID
}
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return storage.Password{}, fmt.Errorf("list local users: %w", err)
}
for _, user := range users {
if user.UserID == rawUserID || user.UserID == selector.userID {
return user, nil
}
}
if len(users) == 0 {
return storage.Password{}, noLocalUsersError(idpStorageFile)
}
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
}
func localUsersEmpty(ctx context.Context, idpStorage storage.Storage) (bool, error) {
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return false, fmt.Errorf("list local users: %w", err)
}
return len(users) == 0, nil
}
func noLocalUsersError(idpStorageFile string) error {
location := ""
if idpStorageFile != "" {
location = fmt.Sprintf(" (%s)", idpStorageFile)
}
return fmt.Errorf("no local users exist in the embedded IdP storage%s; the management server may never have started with this config, or --datadir points at the wrong location", location)
}
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
err := idpStorage.DeleteAuthSession(ctx, userID, idp.LocalConnectorID)
if err == nil || errors.Is(err, storage.ErrNotFound) {
return nil
}
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
}
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{idp.DefaultTOTPAuthenticatorID}
}
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
if err := nbdex.SetClientsMFAChain(ctx, idpStorage, clientIDs, mfaChain); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("embedded IdP client not found; start the management server once before toggling MFA: %w", err)
}
return fmt.Errorf("update MFA chain on embedded IdP clients: %w", err)
}
return nil
}
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
enabledCount := 0
for _, clientID := range clientIDs {
client, err := idpStorage.GetClient(ctx, clientID)
if errors.Is(err, storage.ErrNotFound) {
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
}
if err != nil {
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
}
if hasAuthenticator(client.MFAChain, idp.DefaultTOTPAuthenticatorID) {
enabledCount++
}
}
switch enabledCount {
case 0:
return "disabled", nil
case len(clientIDs):
return "enabled", nil
default:
return "partially enabled", nil
}
}
func hasAuthenticator(chain []string, authenticatorID string) bool {
for _, id := range chain {
if id == authenticatorID {
return true
}
}
return false
}

View File

@@ -0,0 +1,250 @@
package admincmd
import (
"bytes"
"context"
"io"
"log/slog"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
mgmtstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
func newTestIDPStorage(t *testing.T) storage.Storage {
t.Helper()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
require.NoError(t, err)
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
Email: "user@example.com",
Username: "User",
UserID: "user-1",
Hash: hash,
}))
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
UserID: "user-1",
ConnectorID: idp.LocalConnectorID,
MFASecrets: map[string]*storage.MFASecret{
idp.DefaultTOTPAuthenticatorID: {
AuthenticatorID: idp.DefaultTOTPAuthenticatorID,
Type: "TOTP",
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
Confirmed: true,
CreatedAt: time.Now(),
},
},
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
"webauthn": {{CredentialID: []byte("credential")}},
},
}))
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
UserID: "user-1",
ConnectorID: idp.LocalConnectorID,
Nonce: "nonce",
}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientCLI, Name: "CLI"}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientDashboard, Name: "Dashboard"}))
return st
}
func TestRunChangePassword(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!", "")
require.NoError(t, err)
require.Contains(t, out.String(), "Password updated")
user, err := st.GetPassword(ctx, "user@example.com")
require.NoError(t, err)
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunChangePasswordValidatesPassword(t *testing.T) {
st := newTestIDPStorage(t)
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short", "")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid password")
}
func TestRunResetMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
encodedUserID := nbdex.EncodeDexUserID("user-1", idp.LocalConnectorID)
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID}, "")
require.NoError(t, err)
require.Contains(t, out.String(), "MFA reset")
identity, err := st.GetUserIdentity(ctx, "user-1", idp.LocalConnectorID)
require.NoError(t, err)
require.Empty(t, identity.MFASecrets)
require.Empty(t, identity.WebAuthnCredentials)
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.MFASecrets = nil
old.WebAuthnCredentials = nil
return old, nil
}))
var out bytes.Buffer
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"}, "")
require.NoError(t, err)
require.Contains(t, out.String(), "No MFA enrollment found")
}
func TestSetIDPClientsMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, setIDPClientsMFA(ctx, st, true))
status, err := idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "enabled", status)
require.NoError(t, setIDPClientsMFA(ctx, st, false))
status, err = idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "disabled", status)
}
func newTestManagementStore(t *testing.T, localMFAEnabled bool) mgmtstore.Store {
t.Helper()
ctx := context.Background()
st, err := mgmtstore.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, st.Close(ctx)) })
require.NoError(t, st.SaveAccount(ctx, &types.Account{
Id: "account-1",
Settings: &types.Settings{LocalMfaEnabled: localMFAEnabled},
}))
return st
}
func TestRunSetMFAEnabledDoesNotSaveWhenIDPUpdateFails(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
require.Error(t, err)
require.Contains(t, err.Error(), "embedded IdP client")
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.False(t, settings.LocalMfaEnabled)
}
func TestRunSetMFAEnabledUpdatesSettingsAfterIDP(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := newTestIDPStorage(t)
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
require.NoError(t, err)
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.True(t, settings.LocalMfaEnabled)
clientStatus, err := idpClientsMFAStatus(ctx, idpStorage)
require.NoError(t, err)
require.Equal(t, "enabled", clientStatus)
}
func TestRunSetMFAEnabledSucceedsWithNilEventStore(t *testing.T) {
ctx := context.Background()
managementStore := newTestManagementStore(t, false)
idpStorage := newTestIDPStorage(t)
var out bytes.Buffer
var err error
require.NotPanics(t, func() {
err = runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage, EventStore: nil}, &out, true)
})
require.NoError(t, err)
require.Contains(t, out.String(), "Local MFA enabled")
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
require.NoError(t, err)
require.True(t, settings.LocalMfaEnabled)
}
func TestUserSelectorValidate(t *testing.T) {
require.NoError(t, userSelector{email: " user@example.com "}.validate())
require.NoError(t, userSelector{userID: "user-1"}.validate())
require.Error(t, userSelector{}.validate())
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
}
func TestFindLocalUserNotFound(t *testing.T) {
st := newTestIDPStorage(t)
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "")
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "not found"))
}
func TestFindLocalUserZeroUsersIncludesStoragePath(t *testing.T) {
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "/var/lib/netbird/idp.db")
require.Error(t, err)
require.Contains(t, err.Error(), "no local users exist")
require.Contains(t, err.Error(), "/var/lib/netbird/idp.db")
}
func TestUserCommandValidatesSelectorBeforeOpeningStorage(t *testing.T) {
opened := false
cmd := NewCommands(Openers{
IDP: func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
opened = true
return nil
},
})
cmd.SetArgs([]string{"user", "change-password", "--password", "NewPass1!"})
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
err := cmd.Execute()
require.Error(t, err)
require.Contains(t, err.Error(), "provide exactly one")
require.False(t, opened)
}
func TestResolvePasswordInputFromStdin(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetIn(strings.NewReader("NewPass1!\n"))
password, err := resolvePasswordInput(cmd, "", "-")
require.NoError(t, err)
require.Equal(t, "NewPass1!", password)
}
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
require.Error(t, err)
}

View File

@@ -0,0 +1,80 @@
package cmd
import (
"context"
"path"
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
)
func TestApplyAdminDatadirOverrideRelocatesDefaultIDPStorage(t *testing.T) {
oldDatadir := filepath.Join(t.TempDir(), "old")
newDatadir := filepath.Join(t.TempDir(), "new")
for _, defaultFile := range []string{
"",
filepath.Join(oldDatadir, "idp.db"),
path.Join(oldDatadir, "idp.db"),
} {
t.Run(defaultFile, func(t *testing.T) {
cfg := &nbconfig.Config{
EmbeddedIdP: &idp.EmbeddedIdPConfig{
Enabled: true,
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: defaultFile,
},
},
},
}
datadir := oldDatadir
oldAdminDatadir := adminDatadir
adminDatadir = newDatadir
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
applyAdminDatadirOverride(cfg, &datadir)
require.Equal(t, newDatadir, datadir)
require.Equal(t, filepath.Join(newDatadir, "idp.db"), cfg.EmbeddedIdP.Storage.Config.File)
})
}
}
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
eventStore, err := openAdminEventStore(context.Background(), &nbconfig.Config{}, t.TempDir())
require.Error(t, err)
require.Contains(t, err.Error(), "encryption key")
require.Nil(t, eventStore)
}
func TestApplyAdminDatadirOverrideKeepsExplicitIDPStorage(t *testing.T) {
oldDatadir := filepath.Join(t.TempDir(), "old")
newDatadir := filepath.Join(t.TempDir(), "new")
explicitFile := filepath.Join(t.TempDir(), "custom-idp.db")
cfg := &nbconfig.Config{
EmbeddedIdP: &idp.EmbeddedIdPConfig{
Enabled: true,
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: explicitFile,
},
},
},
}
datadir := oldDatadir
oldAdminDatadir := adminDatadir
adminDatadir = newDatadir
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
applyAdminDatadirOverride(cfg, &datadir)
require.Equal(t, newDatadir, datadir)
require.Equal(t, explicitFile, cfg.EmbeddedIdP.Storage.Config.File)
}

View File

@@ -13,6 +13,7 @@ import (
"os"
"os/signal"
"path"
"path/filepath"
"strings"
"syscall"
@@ -209,7 +210,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
cfg.EmbeddedIdP.Storage.Config.File = filepath.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer

View File

@@ -0,0 +1,141 @@
// Package proxycmd provides reusable cobra commands for managing reverse proxy instances.
// Both the management and combined binaries use these commands, each providing
// their own StoreOpener to handle config loading and store initialization.
package proxycmd
import (
"bufio"
"context"
"fmt"
"io"
"strings"
"text/tabwriter"
"github.com/spf13/cobra"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
)
// StoreOpener initializes a store from the command context and calls fn.
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
const disconnectAllConfirmation = "disconnect all proxies"
// NewCommands creates the proxy command tree with the given store opener.
// Returns the parent "proxy" command with the disconnect-all subcommand.
func NewCommands(opener StoreOpener) *cobra.Command {
var dryRun bool
var force bool
proxyCmd := &cobra.Command{
Use: "proxy",
Short: "Manage reverse proxy instances",
Long: "Commands for inspecting and repairing the reverse proxy instances registered with the management server.",
}
disconnectAllCmd := &cobra.Command{
Use: "disconnect-all",
Short: "Force-mark all reverse proxy instances as disconnected",
Long: "Lists all reverse proxy instances and force-marks them as disconnected, regardless of their session state. " +
"Use this to repair stale connection state, e.g. after an unclean management server shutdown. " +
"By default, it asks for manual confirmation before changing state. Use --dry-run to preview without changing state, or --force to skip confirmation. " +
"Run during a maintenance window; affected live proxies may stay hidden until their next heartbeat or reconnect/re-register.",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, s store.Store) error {
return runDisconnectAll(ctx, s, cmd.OutOrStdout(), cmd.InOrStdin(), dryRun, force)
})
},
}
disconnectAllCmd.Flags().BoolVar(&dryRun, "dry-run", false, "List reverse proxy instances that would be disconnected without changing state")
disconnectAllCmd.Flags().BoolVar(&force, "force", false, "Skip the confirmation prompt and apply the repair")
proxyCmd.AddCommand(disconnectAllCmd)
return proxyCmd
}
func runDisconnectAll(ctx context.Context, s store.Store, out io.Writer, in io.Reader, dryRun, force bool) error {
proxies, err := s.GetAllProxies(ctx)
if err != nil {
return fmt.Errorf("list proxies: %w", err)
}
if len(proxies) == 0 {
_, _ = fmt.Fprintln(out, "No reverse proxy instances found.")
return nil
}
toDisconnect := 0
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "ID\tCLUSTER\tIP\tACCOUNT\tSTATUS\tLAST SEEN")
_, _ = fmt.Fprintln(w, "--\t-------\t--\t-------\t------\t---------")
for _, p := range proxies {
if p.Status != rpproxy.StatusDisconnected {
toDisconnect++
}
account := "-"
if p.AccountID != nil {
account = *p.AccountID
}
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
p.ID,
p.ClusterAddress,
p.IPAddress,
account,
p.Status,
p.LastSeen.Format("2006-01-02 15:04:05"),
)
}
if err := w.Flush(); err != nil {
return fmt.Errorf("write proxy list: %w", err)
}
if dryRun {
_, _ = fmt.Fprintf(out, "\nDry run: would force-mark %d of %d reverse proxy instance(s) as disconnected.\n", toDisconnect, len(proxies))
return nil
}
if !force {
confirmed, err := confirmDisconnectAll(out, in)
if err != nil {
return err
}
if !confirmed {
_, _ = fmt.Fprintln(out, "Aborted. No reverse proxy instances were changed.")
return nil
}
}
disconnected, err := s.DisconnectAllProxies(ctx)
if err != nil {
return fmt.Errorf("disconnect proxies: %w", err)
}
_, _ = fmt.Fprintf(out, "\nForce-marked %d of %d reverse proxy instance(s) as disconnected.\n", disconnected, len(proxies))
return nil
}
func confirmDisconnectAll(out io.Writer, in io.Reader) (bool, error) {
if in == nil {
in = strings.NewReader("")
}
_, _ = fmt.Fprintln(out, "\nWARNING: This command changes stored reverse proxy state for every non-disconnected instance.")
_, _ = fmt.Fprintln(out, "Run it during a maintenance window; affected live proxies may stay hidden until "+
"their next heartbeat or reconnect/re-register.")
_, _ = fmt.Fprintf(out, "Type %q to continue: ", disconnectAllConfirmation)
scanner := bufio.NewScanner(in)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return false, fmt.Errorf("read confirmation: %w", err)
}
return false, nil
}
return strings.EqualFold(strings.TrimSpace(scanner.Text()), disconnectAllConfirmation), nil
}

View File

@@ -0,0 +1,180 @@
package proxycmd
import (
"bytes"
"context"
"strings"
"testing"
"time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/server/store"
)
func newTestStore(t *testing.T) store.Store {
t.Helper()
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
return s
}
func seedProxies(t *testing.T, ctx context.Context, s store.Store) {
t.Helper()
accountID := "account-1"
alreadyDisconnectedAt := time.Now().Add(-time.Hour)
seed := []*rpproxy.Proxy{
{
ID: "proxy-1",
SessionID: "session-1",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.1",
LastSeen: time.Now(),
Status: rpproxy.StatusConnected,
},
{
ID: "proxy-2",
SessionID: "session-2",
ClusterAddress: "cluster-b.example.com",
IPAddress: "10.0.0.2",
AccountID: &accountID,
LastSeen: time.Now(),
Status: rpproxy.StatusConnected,
},
{
ID: "proxy-3",
SessionID: "session-3",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.3",
LastSeen: time.Now().Add(-time.Hour),
Status: rpproxy.StatusDisconnected,
DisconnectedAt: &alreadyDisconnectedAt,
},
}
for _, p := range seed {
require.NoError(t, s.SaveProxy(ctx, p))
}
}
func proxiesByID(t *testing.T, ctx context.Context, s store.Store) map[string]*rpproxy.Proxy {
t.Helper()
proxies, err := s.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, proxies, 3)
byID := make(map[string]*rpproxy.Proxy, len(proxies))
for _, p := range proxies {
byID[p.ID] = p
}
return byID
}
func TestRunDisconnectAllWithConfirmation(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(disconnectAllConfirmation+"\n"), false, false))
output := out.String()
require.Contains(t, output, "proxy-1")
require.Contains(t, output, "proxy-2")
require.Contains(t, output, "proxy-3")
require.Contains(t, output, "cluster-a.example.com")
require.Contains(t, output, "account-1")
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
for _, p := range proxiesByID(t, ctx, s) {
require.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", p.ID)
require.NotNil(t, p.DisconnectedAt, "proxy %s should have a disconnected timestamp", p.ID)
}
}
func TestRunDisconnectAllForceSkipsConfirmation(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, true))
output := out.String()
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
}
func TestRunDisconnectAllAbortLeavesProxiesUnchanged(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader("no\n"), false, false))
output := out.String()
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
require.Contains(t, output, "Aborted. No reverse proxy instances were changed.")
byID := proxiesByID(t, ctx, s)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
}
func TestRunDisconnectAllDryRunLeavesProxiesUnchanged(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), true, false))
output := out.String()
require.Contains(t, output, "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
byID := proxiesByID(t, ctx, s)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
}
func TestNewCommandsDisconnectAllDryRun(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
seedProxies(t, ctx, s)
opened := false
cmd := NewCommands(func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
opened = true
return fn(cmd.Context(), s)
})
var out bytes.Buffer
cmd.SetOut(&out)
cmd.SetErr(&out)
cmd.SetIn(strings.NewReader(""))
cmd.SetArgs([]string{"disconnect-all", "--dry-run"})
require.NoError(t, cmd.ExecuteContext(ctx))
require.True(t, opened)
require.Contains(t, out.String(), "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
}
func TestRunDisconnectAllEmpty(t *testing.T) {
ctx := context.Background()
s := newTestStore(t)
var out bytes.Buffer
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, false))
require.Contains(t, out.String(), "No reverse proxy instances found.")
}

View File

@@ -83,7 +83,8 @@ func init() {
rootCmd.AddCommand(migrationCmd)
tc := newTokenCommands()
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(tc)
ac := newAdminCommands()
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(ac)
rootCmd.AddCommand(newLegacyTokenCommand())
}

View File

@@ -1,55 +0,0 @@
package cmd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var tokenDatadir string
// newTokenCommands creates the token command tree with management-specific store opener.
func newTokenCommands() *cobra.Command {
cmd := tokencmd.NewCommands(withTokenStore)
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
return cmd
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -608,11 +608,11 @@ func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
}
conn.cancel()
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
}
conn.cancel()
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
}

View File

@@ -21,8 +21,11 @@ import (
)
const (
staticClientDashboard = "netbird-dashboard"
staticClientCLI = "netbird-cli"
StaticClientDashboard = "netbird-dashboard"
StaticClientCLI = "netbird-cli"
DefaultTOTPAuthenticatorID = "default-totp"
LocalConnectorID = dex.LocalConnectorID
defaultCLIRedirectURL1 = "http://localhost:53000/"
defaultCLIRedirectURL2 = "http://localhost:54000/"
defaultScopes = "openid profile email groups"
@@ -185,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
EnablePasswordDB: true,
StaticClients: []storage.Client{
{
ID: staticClientDashboard,
ID: StaticClientDashboard,
Name: "NetBird Dashboard",
Public: true,
RedirectURIs: redirectURIs,
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
},
{
ID: staticClientCLI,
ID: StaticClientCLI,
Name: "NetBird CLI",
Public: true,
RedirectURIs: redirectURIs,
@@ -254,13 +257,13 @@ func sanitizePostLogoutRedirectURIs(uris []string) []string {
func configureMFA(cfg *dex.YAMLConfig, sessionMaxLifetime, sessionIdleTimeout string, rememberMe bool, sessionCookieEncryptionKey string) error {
cfg.MFA.Authenticators = []dex.MFAAuthenticator{{
ID: "default-totp",
ID: DefaultTOTPAuthenticatorID,
// Has to be caps otherwise it will fail
Type: "TOTP",
Config: map[string]interface{}{
"issuer": "NetBird",
},
ConnectorTypes: []string{"local"},
ConnectorTypes: []string{LocalConnectorID},
}}
if sessionMaxLifetime == "" {
@@ -736,7 +739,7 @@ func (m *EmbeddedIdPManager) GetDefaultScopes() string {
// GetCLIClientID returns the client ID for CLI authentication.
func (m *EmbeddedIdPManager) GetCLIClientID() string {
return staticClientCLI
return StaticClientCLI
}
// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client.
@@ -775,7 +778,7 @@ func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
func (m *EmbeddedIdPManager) GetClientIDs() []string {
return []string{staticClientDashboard, staticClientCLI}
return []string{StaticClientDashboard, StaticClientCLI}
}
// GetUserIDClaim returns the JWT claim name used for user identification.
@@ -792,11 +795,11 @@ func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
func (m *EmbeddedIdPManager) SetMFAEnabled(ctx context.Context, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{"default-totp"}
mfaChain = []string{DefaultTOTPAuthenticatorID}
}
if err := m.provider.SetClientsMFAChain(ctx, []string{
staticClientCLI,
staticClientDashboard,
StaticClientCLI,
StaticClientDashboard,
}, mfaChain); err != nil {
return fmt.Errorf("failed to set MFA enabled=%v: %w", enabled, err)
}

View File

@@ -331,7 +331,7 @@ func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *tes
var cliRedirectURIs []string
for _, client := range yamlConfig.StaticClients {
if client.ID == staticClientCLI {
if client.ID == StaticClientCLI {
cliRedirectURIs = client.RedirectURIs
break
}

View File

@@ -6088,6 +6088,37 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
return nil
}
// GetAllProxies returns all reverse proxy instance rows.
func (s *SqlStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
var proxies []*proxy.Proxy
result := s.db.Order("cluster_address, id").Find(&proxies)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get proxies: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get proxies")
}
return proxies, nil
}
// DisconnectAllProxies force-marks every proxy that is not already disconnected
// as disconnected, regardless of session ID. Unlike DisconnectProxy it is not
// session-guarded: it is an administrative repair helper, not part of the
// connection lifecycle. last_seen is left untouched so the stale-proxy reaper
// keeps working off the real last heartbeat. Returns the number of proxies updated.
func (s *SqlStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
result := s.db.
Model(&proxy.Proxy{}).
Where("status != ?", proxy.StatusDisconnected).
Updates(map[string]any{
"status": proxy.StatusDisconnected,
"disconnected_at": time.Now(),
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to disconnect all proxies: %v", result.Error)
return 0, status.Errorf(status.Internal, "failed to disconnect all proxies")
}
return result.RowsAffected, nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
now := time.Now()
@@ -6095,7 +6126,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
Update("last_seen", now)
Updates(map[string]any{
"last_seen": now,
"status": proxy.StatusConnected,
"disconnected_at": nil,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)

View File

@@ -0,0 +1,156 @@
package store
import (
"context"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// TestSqlStore_DisconnectAllProxies guards the administrative
// force-disconnect helper:
//
// 1. Every proxy that is not already disconnected is marked
// disconnected regardless of its session ID (unlike
// DisconnectProxy, which is session-guarded).
// 2. Rows that are already disconnected are left untouched, so their
// original disconnected_at is preserved and the returned count
// reflects only the rows that actually changed.
// 3. last_seen is not modified — the stale-proxy reaper keeps working
// off the real last heartbeat.
func TestSqlStore_DisconnectAllProxies(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
lastSeenFresh := time.Now().Add(-30 * time.Second)
lastSeenStale := time.Now().Add(-30 * time.Minute)
oldDisconnectedAt := time.Now().Add(-time.Hour)
accountID := "acct-disconnect"
proxies := []*rpproxy.Proxy{
{
ID: "p-connected-fresh",
SessionID: "sess-1",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.1",
LastSeen: lastSeenFresh,
Status: rpproxy.StatusConnected,
},
{
ID: "p-connected-stale",
SessionID: "sess-2",
ClusterAddress: "cluster-b.example.com",
IPAddress: "10.0.0.2",
AccountID: &accountID,
LastSeen: lastSeenStale,
Status: rpproxy.StatusConnected,
},
{
ID: "p-already-disconnected",
SessionID: "sess-3",
ClusterAddress: "cluster-a.example.com",
IPAddress: "10.0.0.3",
LastSeen: lastSeenStale,
Status: rpproxy.StatusDisconnected,
DisconnectedAt: &oldDisconnectedAt,
},
}
for _, p := range proxies {
require.NoError(t, store.SaveProxy(ctx, p))
}
all, err := store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 3)
disconnected, err := store.DisconnectAllProxies(ctx)
require.NoError(t, err)
assert.Equal(t, int64(2), disconnected)
all, err = store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 3)
byID := make(map[string]*rpproxy.Proxy, len(all))
for _, p := range all {
byID[p.ID] = p
}
for id, p := range byID {
assert.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", id)
require.NotNil(t, p.DisconnectedAt, "proxy %s should have disconnected_at set", id)
}
// force-marked rows carry a fresh disconnected_at; the untouched row keeps its original one
assert.WithinDuration(t, time.Now(), *byID["p-connected-fresh"].DisconnectedAt, 10*time.Second)
assert.WithinDuration(t, time.Now(), *byID["p-connected-stale"].DisconnectedAt, 10*time.Second)
assert.WithinDuration(t, oldDisconnectedAt, *byID["p-already-disconnected"].DisconnectedAt, time.Second)
// last_seen is preserved so the stale reaper schedule is unaffected
assert.WithinDuration(t, lastSeenFresh, byID["p-connected-fresh"].LastSeen, time.Second)
assert.WithinDuration(t, lastSeenStale, byID["p-connected-stale"].LastSeen, time.Second)
// idempotent: a second run has nothing left to update
disconnected, err = store.DisconnectAllProxies(ctx)
require.NoError(t, err)
assert.Equal(t, int64(0), disconnected)
})
}
func TestSqlStore_UpdateProxyHeartbeatRestoresDisconnectedCurrentSession(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
ctx := context.Background()
proxy := &rpproxy.Proxy{
ID: "p-heartbeat",
SessionID: "sess-heartbeat",
ClusterAddress: "cluster-heartbeat.example.com",
IPAddress: "10.0.0.10",
LastSeen: time.Now().Add(-30 * time.Second),
Status: rpproxy.StatusConnected,
}
require.NoError(t, store.SaveProxy(ctx, proxy))
disconnected, err := store.DisconnectAllProxies(ctx)
require.NoError(t, err)
require.Equal(t, int64(1), disconnected)
require.NoError(t, store.UpdateProxyHeartbeat(ctx, &rpproxy.Proxy{ID: proxy.ID, SessionID: proxy.SessionID}))
all, err := store.GetAllProxies(ctx)
require.NoError(t, err)
require.Len(t, all, 1)
assert.Equal(t, rpproxy.StatusConnected, all[0].Status)
assert.Nil(t, all[0].DisconnectedAt)
assert.WithinDuration(t, time.Now(), all[0].LastSeen, 10*time.Second)
addresses, err := store.GetActiveProxyClusterAddresses(ctx)
require.NoError(t, err)
assert.Contains(t, addresses, proxy.ClusterAddress)
})
}
func TestSqlStore_GetAllProxies_Empty(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
all, err := store.GetAllProxies(context.Background())
require.NoError(t, err)
assert.Empty(t, all)
})
}

View File

@@ -323,6 +323,8 @@ type Store interface {
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error)
DisconnectAllProxies(ctx context.Context) (int64, error)
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)

View File

@@ -745,6 +745,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// DisconnectAllProxies mocks base method.
func (m *MockStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DisconnectAllProxies", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// DisconnectAllProxies indicates an expected call of DisconnectAllProxies.
func (mr *MockStoreMockRecorder) DisconnectAllProxies(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectAllProxies", reflect.TypeOf((*MockStore)(nil).DisconnectAllProxies), ctx)
}
// DisconnectProxy mocks base method.
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
m.ctrl.T.Helper()
@@ -1389,6 +1404,21 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
}
// GetAllProxies mocks base method.
func (m *MockStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllProxies", ctx)
ret0, _ := ret[0].([]*proxy.Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllProxies indicates an expected call of GetAllProxies.
func (mr *MockStoreMockRecorder) GetAllProxies(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxies", reflect.TypeOf((*MockStore)(nil).GetAllProxies), ctx)
}
// GetAllProxyAccessTokens mocks base method.
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
@@ -1521,6 +1551,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetExpiredEphemeralServices mocks base method.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
@@ -1566,6 +1611,21 @@ func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, gr
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetGroupsByIDs mocks base method.
func (m *MockStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types2.Group, error) {
m.ctrl.T.Helper()
@@ -1850,6 +1910,21 @@ func (mr *MockStoreMockRecorder) GetPeerIDByKey(ctx, lockStrength, key interface
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDByKey", reflect.TypeOf((*MockStore)(nil).GetPeerIDByKey), ctx, lockStrength, key)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetPeerIdByLabel mocks base method.
func (m *MockStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID, hostname string) (string, error) {
m.ctrl.T.Helper()
@@ -1925,51 +2000,6 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -1849,12 +1849,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
const minPasswordLength = 8
// validatePassword checks password strength requirements:
// validatePassword checks password strength requirements.
func validatePassword(password string) error {
return ValidatePassword(password)
}
// ValidatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func validatePassword(password string) error {
func ValidatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}