mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-23 08:19:56 +00:00
Compare commits
3 Commits
v0.73.2
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
520370a8b0 | ||
|
|
b5a16a1898 | ||
|
|
449b5cbb80 |
91
combined/cmd/admin.go
Normal file
91
combined/cmd/admin.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newAdminCommands creates the admin command tree with combined-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(withAdminResources)
|
||||
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
|
||||
mgmtConfig, err := cfg.ToManagementConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create management config: %w", err)
|
||||
}
|
||||
|
||||
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := idpStorage.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminTokenStore opens only the management store for admin token commands.
|
||||
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := managementStore.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, managementStore, cfg)
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
rootCmd.AddCommand(newAdminCommands())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
return tokencmd.NewCommands(withTokenStore)
|
||||
}
|
||||
|
||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
89
management/cmd/admin.go
Normal file
89
management/cmd/admin.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var adminDatadir string
|
||||
|
||||
// newAdminCommands creates the admin command tree with management-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(withAdminResources)
|
||||
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
|
||||
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources initializes logging, loads config, opens the management store
|
||||
// and embedded IdP storage, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
|
||||
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := idpStorage.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminTokenStore opens only the management store for admin token commands.
|
||||
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if adminDatadir != "" {
|
||||
oldDatadir := datadir
|
||||
datadir = adminDatadir
|
||||
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
|
||||
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
|
||||
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
|
||||
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := managementStore.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, managementStore, config)
|
||||
}
|
||||
441
management/cmd/admin/admin.go
Normal file
441
management/cmd/admin/admin.go
Normal file
@@ -0,0 +1,441 @@
|
||||
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own opener to handle config loading and storage initialization.
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
localConnectorID = "local"
|
||||
dashboardClientID = "netbird-dashboard"
|
||||
cliClientID = "netbird-cli"
|
||||
defaultTOTPAuthenticatorID = "default-totp"
|
||||
)
|
||||
|
||||
// Resources contains the storages required by the admin commands.
|
||||
type Resources struct {
|
||||
Store store.Store
|
||||
IDPStorage storage.Storage
|
||||
}
|
||||
|
||||
// Opener initializes command resources from the command context and calls fn.
|
||||
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
|
||||
|
||||
type userSelector struct {
|
||||
email string
|
||||
userID string
|
||||
}
|
||||
|
||||
func (s userSelector) normalized() userSelector {
|
||||
return userSelector{
|
||||
email: strings.TrimSpace(s.email),
|
||||
userID: strings.TrimSpace(s.userID),
|
||||
}
|
||||
}
|
||||
|
||||
func (s userSelector) validate() error {
|
||||
s = s.normalized()
|
||||
if (s.email == "") == (s.userID == "") {
|
||||
return fmt.Errorf("provide exactly one of --email or --user-id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewCommands creates the admin command tree with the given resource opener.
|
||||
func NewCommands(opener Opener) *cobra.Command {
|
||||
adminCmd := &cobra.Command{
|
||||
Use: "admin",
|
||||
Short: "Self-hosted administrator helpers",
|
||||
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
|
||||
}
|
||||
|
||||
userCmd := &cobra.Command{
|
||||
Use: "user",
|
||||
Short: "Manage local embedded IdP users",
|
||||
}
|
||||
|
||||
var passwordSelector userSelector
|
||||
var password string
|
||||
var passwordFile string
|
||||
passwordCmd := &cobra.Command{
|
||||
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
|
||||
Aliases: []string{"set-password"},
|
||||
Short: "Change a local user's password",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(passwordCmd, &passwordSelector)
|
||||
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
|
||||
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
|
||||
|
||||
var resetSelector userSelector
|
||||
resetMFACmd := &cobra.Command{
|
||||
Use: "reset-mfa (--email email | --user-id id)",
|
||||
Short: "Reset a local user's MFA enrollment",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(resetMFACmd, &resetSelector)
|
||||
|
||||
userCmd.AddCommand(passwordCmd, resetMFACmd)
|
||||
|
||||
mfaCmd := &cobra.Command{
|
||||
Use: "mfa",
|
||||
Short: "Manage local MFA for embedded IdP users",
|
||||
}
|
||||
|
||||
enableCmd := &cobra.Command{
|
||||
Use: "enable",
|
||||
Short: "Enable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
disableCmd := &cobra.Command{
|
||||
Use: "disable",
|
||||
Short: "Disable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show local MFA status",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
|
||||
adminCmd.AddCommand(userCmd, mfaCmd)
|
||||
return adminCmd
|
||||
}
|
||||
|
||||
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
|
||||
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
|
||||
}
|
||||
|
||||
yamlConfig, err := cfg.ToYAMLConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build embedded IdP config: %w", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
st, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
|
||||
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
|
||||
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
|
||||
}
|
||||
|
||||
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
|
||||
if password != "" && passwordFile != "" {
|
||||
return "", fmt.Errorf("provide only one of --password or --password-file")
|
||||
}
|
||||
if passwordFile == "" {
|
||||
return password, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
if passwordFile == "-" {
|
||||
data, err = io.ReadAll(cmd.InOrStdin())
|
||||
} else {
|
||||
data, err = os.ReadFile(passwordFile)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read password: %w", err)
|
||||
}
|
||||
return strings.TrimRight(string(data), "\r\n"), nil
|
||||
}
|
||||
|
||||
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if password == "" {
|
||||
return fmt.Errorf("password is required")
|
||||
}
|
||||
if err := server.ValidatePassword(password); err != nil {
|
||||
return fmt.Errorf("invalid password: %w", err)
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
||||
old.Hash = hash
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update password for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reset := false
|
||||
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
|
||||
old.MFASecrets = map[string]*storage.MFASecret{}
|
||||
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
|
||||
return old, nil
|
||||
})
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reset {
|
||||
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accounts := resources.Store.GetAllAccounts(ctx)
|
||||
if len(accounts) != 1 {
|
||||
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
|
||||
}
|
||||
|
||||
settings := &types.Settings{}
|
||||
if accounts[0].Settings != nil {
|
||||
settings = accounts[0].Settings.Copy()
|
||||
}
|
||||
settings.LocalMfaEnabled = enabled
|
||||
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
|
||||
return fmt.Errorf("save local MFA account setting: %w", err)
|
||||
}
|
||||
|
||||
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := "disabled"
|
||||
if enabled {
|
||||
state = "enabled"
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accounts := resources.Store.GetAllAccounts(ctx)
|
||||
accountStatus := "unknown"
|
||||
if len(accounts) == 1 && accounts[0].Settings != nil {
|
||||
accountStatus = "disabled"
|
||||
if accounts[0].Settings.LocalMfaEnabled {
|
||||
accountStatus = "enabled"
|
||||
}
|
||||
}
|
||||
|
||||
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
|
||||
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return storage.Password{}, err
|
||||
}
|
||||
|
||||
if selector.email != "" {
|
||||
user, err := idpStorage.GetPassword(ctx, selector.email)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
|
||||
}
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
rawUserID := selector.userID
|
||||
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
|
||||
rawUserID = decodedUserID
|
||||
}
|
||||
|
||||
users, err := idpStorage.ListPasswords(ctx)
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("list local users: %w", err)
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.UserID == rawUserID || user.UserID == selector.userID {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
|
||||
}
|
||||
|
||||
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
|
||||
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
|
||||
if err == nil || errors.Is(err, storage.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
|
||||
var mfaChain []string
|
||||
if enabled {
|
||||
mfaChain = []string{defaultTOTPAuthenticatorID}
|
||||
}
|
||||
|
||||
for _, clientID := range []string{cliClientID, dashboardClientID} {
|
||||
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = mfaChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
|
||||
}
|
||||
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
|
||||
clientIDs := []string{cliClientID, dashboardClientID}
|
||||
enabledCount := 0
|
||||
for _, clientID := range clientIDs {
|
||||
client, err := idpStorage.GetClient(ctx, clientID)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
|
||||
}
|
||||
if err != nil {
|
||||
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
|
||||
enabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
switch enabledCount {
|
||||
case 0:
|
||||
return "disabled", nil
|
||||
case len(clientIDs):
|
||||
return "enabled", nil
|
||||
default:
|
||||
return "partially enabled", nil
|
||||
}
|
||||
}
|
||||
|
||||
func hasAuthenticator(chain []string, authenticatorID string) bool {
|
||||
for _, id := range chain {
|
||||
if id == authenticatorID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
160
management/cmd/admin/admin_test.go
Normal file
160
management/cmd/admin/admin_test.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
)
|
||||
|
||||
func newTestIDPStorage(t *testing.T) storage.Storage {
|
||||
t.Helper()
|
||||
|
||||
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
|
||||
Email: "user@example.com",
|
||||
Username: "User",
|
||||
UserID: "user-1",
|
||||
Hash: hash,
|
||||
}))
|
||||
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
|
||||
UserID: "user-1",
|
||||
ConnectorID: localConnectorID,
|
||||
MFASecrets: map[string]*storage.MFASecret{
|
||||
defaultTOTPAuthenticatorID: {
|
||||
AuthenticatorID: defaultTOTPAuthenticatorID,
|
||||
Type: "TOTP",
|
||||
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
|
||||
Confirmed: true,
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
|
||||
"webauthn": {{CredentialID: []byte("credential")}},
|
||||
},
|
||||
}))
|
||||
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
|
||||
UserID: "user-1",
|
||||
ConnectorID: localConnectorID,
|
||||
Nonce: "nonce",
|
||||
}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
|
||||
|
||||
return st
|
||||
}
|
||||
|
||||
func TestRunChangePassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "Password updated")
|
||||
|
||||
user, err := st.GetPassword(ctx, "user@example.com")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunChangePasswordValidatesPassword(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid password")
|
||||
}
|
||||
|
||||
func TestRunResetMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
|
||||
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "MFA reset")
|
||||
|
||||
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, identity.MFASecrets)
|
||||
require.Empty(t, identity.WebAuthnCredentials)
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
old.MFASecrets = nil
|
||||
old.WebAuthnCredentials = nil
|
||||
return old, nil
|
||||
}))
|
||||
|
||||
var out bytes.Buffer
|
||||
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "No MFA enrollment found")
|
||||
}
|
||||
|
||||
func TestSetIDPClientsMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, true))
|
||||
status, err := idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "enabled", status)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, false))
|
||||
status, err = idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "disabled", status)
|
||||
}
|
||||
|
||||
func TestUserSelectorValidate(t *testing.T) {
|
||||
require.NoError(t, userSelector{email: " user@example.com "}.validate())
|
||||
require.NoError(t, userSelector{userID: "user-1"}.validate())
|
||||
require.Error(t, userSelector{}.validate())
|
||||
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
|
||||
}
|
||||
|
||||
func TestFindLocalUserNotFound(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "not found"))
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputFromStdin(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(strings.NewReader("NewPass1!\n"))
|
||||
|
||||
password, err := resolvePasswordInput(cmd, "", "-")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "NewPass1!", password)
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
|
||||
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func init() {
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
tc := newTokenCommands()
|
||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(tc)
|
||||
ac := newAdminCommands()
|
||||
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(ac)
|
||||
}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var tokenDatadir string
|
||||
|
||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(withTokenStore)
|
||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if tokenDatadir != "" {
|
||||
datadir = tokenDatadir
|
||||
}
|
||||
|
||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
@@ -1205,7 +1205,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
|
||||
return nil, msg
|
||||
}
|
||||
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
|
||||
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
|
||||
if err != nil {
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
@@ -1254,10 +1254,7 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
|
||||
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
|
||||
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
check := toProtocolCheck(postureCheck)
|
||||
if check != nil {
|
||||
protoChecks = append(protoChecks, check)
|
||||
}
|
||||
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
|
||||
}
|
||||
|
||||
return protoChecks
|
||||
@@ -1281,9 +1278,5 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
|
||||
}
|
||||
}
|
||||
|
||||
if len(protoCheck.Files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return protoCheck
|
||||
}
|
||||
|
||||
@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
|
||||
// concurrent stream that started earlier loses the optimistic-lock race
|
||||
// in MarkPeerConnected and bails without writing.
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
|
||||
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
|
||||
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -123,7 +123,7 @@ type Manager interface {
|
||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
|
||||
@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks base method.
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
|
||||
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
|
||||
}
|
||||
|
||||
// SyncUserJWTGroups mocks base method.
|
||||
|
||||
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1935,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1956,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1980,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1990,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2052,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2093,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
|
||||
@@ -39,7 +39,7 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@@ -114,7 +114,7 @@ type MockAccountManager struct {
|
||||
GetIdpManagerFunc func() idp.Manager
|
||||
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
|
||||
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
|
||||
}
|
||||
|
||||
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
|
||||
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
|
||||
if am.SyncPeerMetaFunc != nil {
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
|
||||
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
//
|
||||
// Disconnects use MarkPeerDisconnected and require the session to match
|
||||
// exactly; see PeerStatus.SessionStartedAt for the protocol.
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
|
||||
@@ -102,6 +102,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
|
||||
}
|
||||
|
||||
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -191,28 +195,24 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolvePeerLocation looks up the geo location for realIP, returning nil when
|
||||
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
|
||||
// what the peer already has, or the lookup failed. Geo lookups are skipped on
|
||||
// same-IP reconnects since they are comparatively expensive. The returned value
|
||||
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
|
||||
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
// updatePeerLocationIfChanged refreshes the geolocation on a separate
|
||||
// row update, only when the connection IP actually changed. Geo lookups
|
||||
// are expensive so we skip same-IP reconnects.
|
||||
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
return
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
CityName: location.City.Names.En,
|
||||
GeoNameID: location.City.GeonameID,
|
||||
peer.Location.ConnectionIP = realIP
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -980,8 +980,7 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
var peer *nbpeer.Peer
|
||||
var ipv6CapabilityChanged bool
|
||||
var metaDiff nbpeer.MetaDiff
|
||||
var updated, versionChanged, ipv6CapabilityChanged bool
|
||||
var err error
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1011,10 +1010,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
|
||||
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
|
||||
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
|
||||
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
|
||||
if metaDiff.Updated() {
|
||||
if updated {
|
||||
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
|
||||
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
|
||||
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
@@ -1042,10 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -1062,8 +1059,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
// metadata change that flips a posture result removes this peer from others'
|
||||
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
|
||||
// back to the resolver.
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
|
||||
if peerNotValid || metaChangeAffectedPosture {
|
||||
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
|
||||
if peerNotValid || (metaUpdated && hasPostureChecks) {
|
||||
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
|
||||
}
|
||||
return affectedPeerIDsFromNetworkMap(nmap, peerID)
|
||||
@@ -1173,7 +1170,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.Meta = login.Meta
|
||||
peer.UpdateMetaIfNew(ctx, login.Meta)
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -256,18 +256,14 @@ func (p *Peer) Copy() *Peer {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
|
||||
// new information is provided. newLocation is the geo location resolved from the
|
||||
// peer's current connection IP, or nil when there is nothing to apply (geo
|
||||
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
|
||||
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
|
||||
// diff.Updated() reports whether the peer needs to be persisted.
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
|
||||
// UpdateMetaIfNew updates peer's system metadata if new information is provided
|
||||
// returns true if meta was updated, false otherwise
|
||||
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
|
||||
if meta.isEmpty() {
|
||||
return MetaDiff{}
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
versionChanged := p.Meta.WtVersion != meta.WtVersion
|
||||
versionChanged = p.Meta.WtVersion != meta.WtVersion
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
@@ -276,177 +272,97 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
|
||||
|
||||
oldVersion := p.Meta.WtVersion
|
||||
|
||||
diff := diffMeta(p.Meta, meta)
|
||||
if diff.Any() {
|
||||
diff := metaDiff(p.Meta, meta)
|
||||
if len(diff) != 0 {
|
||||
p.Meta = meta
|
||||
}
|
||||
diff.VersionChanged = versionChanged
|
||||
|
||||
locationInfo := ""
|
||||
if newLocation != nil {
|
||||
p.Location = *newLocation
|
||||
diff.LocationChanged = true
|
||||
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
|
||||
updated = true
|
||||
}
|
||||
|
||||
versionInfo := ""
|
||||
if diff.VersionChanged {
|
||||
if versionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
|
||||
if len(diff) > 0 || versionChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
|
||||
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
|
||||
// maps to a single struct field, except Environment, which is split into Cloud and
|
||||
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
|
||||
// existing log line and isEqual can be derived from the same comparison.
|
||||
//
|
||||
// VersionChanged and LocationChanged sit outside the per-meta-field set:
|
||||
// VersionChanged tracks the WireGuard client version specifically (compared before
|
||||
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
|
||||
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
|
||||
// Neither contributes an entry to Changed, so the field-coverage accounting stays
|
||||
// driven purely by the PeerSystemMeta comparison.
|
||||
type MetaDiff struct {
|
||||
Hostname bool
|
||||
GoOS bool
|
||||
Kernel bool
|
||||
KernelVersion bool
|
||||
Core bool
|
||||
Platform bool
|
||||
OS bool
|
||||
OSVersion bool
|
||||
WtVersion bool
|
||||
UIVersion bool
|
||||
SystemSerialNumber bool
|
||||
SystemProductName bool
|
||||
SystemManufacturer bool
|
||||
EnvironmentCloud bool
|
||||
EnvironmentPlatform bool
|
||||
Flags bool
|
||||
Capabilities bool
|
||||
NetworkAddresses bool
|
||||
Files bool
|
||||
|
||||
VersionChanged bool
|
||||
LocationChanged bool
|
||||
|
||||
Changed []string
|
||||
}
|
||||
|
||||
// Any reports whether any PeerSystemMeta field changed.
|
||||
func (d MetaDiff) Any() bool {
|
||||
return len(d.Changed) != 0
|
||||
}
|
||||
|
||||
// Updated reports whether the peer needs to be persisted: any meta field changed
|
||||
// or the geo location changed. The version flag alone does not imply a write,
|
||||
// since a version change is also reflected in the WtVersion meta field.
|
||||
func (d MetaDiff) Updated() bool {
|
||||
return d.Any() || d.LocationChanged || d.VersionChanged
|
||||
return updated, versionChanged
|
||||
}
|
||||
|
||||
// metaDiff returns a human-readable list of the fields that differ between the
|
||||
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
|
||||
// source of truth for meta comparison: isEqual reports equality as an empty
|
||||
// diff, so the log line can never disagree with the change decision. Slices are
|
||||
// cloned before sorting, so callers' meta is not mutated.
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
return diffMeta(oldMeta, newMeta).Changed
|
||||
}
|
||||
|
||||
// diffMeta compares two metas field by field, returning both a per-field flag set
|
||||
// (for callers that need to know exactly what changed, e.g. matching against
|
||||
// posture checks) and the human-readable Changed list. It is the single source of
|
||||
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
|
||||
// line, the change decision, and the flags can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
|
||||
var d MetaDiff
|
||||
var diff []string
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
d.Hostname = true
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
d.GoOS = true
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
d.Kernel = true
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
d.KernelVersion = true
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
d.Core = true
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
d.Platform = true
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
d.OS = true
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
d.OSVersion = true
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
d.WtVersion = true
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
d.UIVersion = true
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
d.SystemSerialNumber = true
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
d.SystemProductName = true
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
d.SystemManufacturer = true
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
d.EnvironmentCloud = true
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
d.EnvironmentPlatform = true
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
d.Flags = true
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
d.Capabilities = true
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
d.NetworkAddresses = true
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
d.Files = true
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return d
|
||||
return diff
|
||||
}
|
||||
|
||||
// sameMultiset reports whether two slices contain the same elements with the
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -52,34 +51,6 @@ type Checks struct {
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AffectsPosture reports whether the peer metadata changes described by diff can
|
||||
// alter the outcome of any of the given posture checks. It maps each check kind to
|
||||
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
|
||||
// does not force a posture re-evaluation.
|
||||
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.Checks.ProcessCheck != nil && diff.Files {
|
||||
return true
|
||||
}
|
||||
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
|
||||
return true
|
||||
}
|
||||
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
|
||||
return true
|
||||
}
|
||||
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
|
||||
return true
|
||||
}
|
||||
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
|
||||
@@ -581,6 +581,28 @@ func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accoun
|
||||
return result.RowsAffected > 0, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
|
||||
var peerCopy nbpeer.Peer
|
||||
// Since the location field has been migrated to JSON serialization,
|
||||
// updating the struct ensures the correct data format is inserted into the database.
|
||||
peerCopy.Location = peerWithLocation.Location
|
||||
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
|
||||
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
|
||||
@@ -618,6 +618,56 @@ func TestSqlStore_SavePeerStatus(t *testing.T) {
|
||||
assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal")
|
||||
}
|
||||
|
||||
func TestSqlStore_SavePeerLocation(t *testing.T) {
|
||||
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
|
||||
require.NoError(t, err)
|
||||
|
||||
peer := &nbpeer.Peer{
|
||||
AccountID: account.Id,
|
||||
ID: "testpeer",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("0.0.0.0"),
|
||||
CountryCode: "YY",
|
||||
CityName: "City",
|
||||
GeoNameID: 1,
|
||||
},
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
}
|
||||
// error is expected as peer is not in store yet
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
|
||||
account.Peers[peer.ID] = peer
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
|
||||
peer.Location.CountryCode = "DE"
|
||||
peer.Location.CityName = "Berlin"
|
||||
peer.Location.GeoNameID = 2950159
|
||||
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID])
|
||||
assert.NoError(t, err)
|
||||
|
||||
account, err = store.GetAccount(context.Background(), account.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual := account.Peers[peer.ID].Location
|
||||
assert.Equal(t, peer.Location, actual)
|
||||
|
||||
peer.ID = "non-existing-peer"
|
||||
err = store.SavePeerLocation(context.Background(), account.Id, peer)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
|
||||
}
|
||||
|
||||
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
|
||||
@@ -185,6 +185,7 @@ type Store interface {
|
||||
// recorded by the database. Returns true when the update happened,
|
||||
// false when a newer session has taken over.
|
||||
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
|
||||
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
|
||||
DeletePeer(ctx context.Context, accountID string, peerID string) error
|
||||
|
||||
|
||||
@@ -2968,6 +2968,20 @@ func (mr *MockStoreMockRecorder) SavePeer(ctx, accountID, peer interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeer", reflect.TypeOf((*MockStore)(nil).SavePeer), ctx, accountID, peer)
|
||||
}
|
||||
|
||||
// SavePeerLocation mocks base method.
|
||||
func (m *MockStore) SavePeerLocation(ctx context.Context, accountID string, peer *peer.Peer) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SavePeerLocation", ctx, accountID, peer)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SavePeerLocation indicates an expected call of SavePeerLocation.
|
||||
func (mr *MockStoreMockRecorder) SavePeerLocation(ctx, accountID, peer interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerLocation", reflect.TypeOf((*MockStore)(nil).SavePeerLocation), ctx, accountID, peer)
|
||||
}
|
||||
|
||||
// SavePeerStatus mocks base method.
|
||||
func (m *MockStore) SavePeerStatus(ctx context.Context, accountID, peerID string, status peer.PeerStatus) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -12,9 +12,6 @@ type PeerSync struct {
|
||||
WireGuardPubKey string
|
||||
// Meta is the system information passed by peer, must be always present
|
||||
Meta nbpeer.PeerSystemMeta
|
||||
// RealIP is the peer's connection IP, used to refresh its geo location.
|
||||
// May be nil when the request has no associated connection IP.
|
||||
RealIP net.IP
|
||||
// UpdateAccountPeers indicate updating account peers,
|
||||
// which occurs when the peer's metadata is updated
|
||||
UpdateAccountPeers bool
|
||||
|
||||
@@ -1847,12 +1847,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// validatePassword checks password strength requirements:
|
||||
// validatePassword checks password strength requirements.
|
||||
func validatePassword(password string) error {
|
||||
return ValidatePassword(password)
|
||||
}
|
||||
|
||||
// ValidatePassword checks password strength requirements:
|
||||
// - Minimum 8 characters
|
||||
// - At least 1 digit
|
||||
// - At least 1 uppercase letter
|
||||
// - At least 1 special character
|
||||
func validatePassword(password string) error {
|
||||
func ValidatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user