mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)
Embed Dex as a built-in IdP to simplify self-hosting setup. Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management. more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
This commit is contained in:
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/server"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
|
||||
@@ -135,76 +136,208 @@ var (
|
||||
|
||||
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
|
||||
loadedConfig := &nbconfig.Config{}
|
||||
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
|
||||
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
applyCommandLineOverrides(loadedConfig)
|
||||
|
||||
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
|
||||
err := applyEmbeddedIdPConfig(loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logConfigInfo(loadedConfig)
|
||||
|
||||
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loadedConfig, nil
|
||||
}
|
||||
|
||||
// applyCommandLineOverrides applies command-line flag overrides to the config
|
||||
func applyCommandLineOverrides(cfg *nbconfig.Config) {
|
||||
if mgmtLetsencryptDomain != "" {
|
||||
loadedConfig.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
|
||||
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
|
||||
}
|
||||
if mgmtDataDir != "" {
|
||||
loadedConfig.Datadir = mgmtDataDir
|
||||
cfg.Datadir = mgmtDataDir
|
||||
}
|
||||
|
||||
if certKey != "" && certFile != "" {
|
||||
loadedConfig.HttpConfig.CertFile = certFile
|
||||
loadedConfig.HttpConfig.CertKey = certKey
|
||||
cfg.HttpConfig.CertFile = certFile
|
||||
cfg.HttpConfig.CertKey = certKey
|
||||
}
|
||||
}
|
||||
|
||||
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
|
||||
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
|
||||
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
|
||||
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint != "" {
|
||||
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
|
||||
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
|
||||
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
// apply some defaults based on the EmbeddedIdP config
|
||||
if disableSingleAccMode {
|
||||
// Embedded IdP requires single account mode - multiple account mode is not supported
|
||||
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP. Please remove --disable-single-account-mode flag")
|
||||
}
|
||||
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
|
||||
userDeleteFromIDPEnabled = true
|
||||
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
|
||||
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
|
||||
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
|
||||
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||
|
||||
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE)) {
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
|
||||
|
||||
u, err := url.Parse(oidcEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
|
||||
|
||||
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
|
||||
}
|
||||
}
|
||||
|
||||
if loadedConfig.PKCEAuthorizationFlow != nil {
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
|
||||
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
||||
}
|
||||
// Ensure HttpConfig exists
|
||||
if cfg.HttpConfig == nil {
|
||||
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
|
||||
}
|
||||
|
||||
if loadedConfig.Relay != nil {
|
||||
log.Infof("Relay addresses: %v", loadedConfig.Relay.Addresses)
|
||||
// Set storage defaults based on Datadir
|
||||
if cfg.EmbeddedIdP.Storage.Type == "" {
|
||||
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
|
||||
}
|
||||
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
|
||||
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
|
||||
}
|
||||
|
||||
return loadedConfig, err
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
// Set AuthIssuer from EmbeddedIdP issuer
|
||||
if cfg.HttpConfig.AuthIssuer == "" {
|
||||
cfg.HttpConfig.AuthIssuer = issuer
|
||||
}
|
||||
|
||||
// Set AuthAudience to the dashboard client ID
|
||||
if cfg.HttpConfig.AuthAudience == "" {
|
||||
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
|
||||
}
|
||||
|
||||
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
|
||||
if cfg.HttpConfig.AuthUserIDClaim == "" {
|
||||
cfg.HttpConfig.AuthUserIDClaim = "sub"
|
||||
}
|
||||
|
||||
// Set AuthKeysLocation to the JWKS endpoint
|
||||
if cfg.HttpConfig.AuthKeysLocation == "" {
|
||||
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
|
||||
}
|
||||
|
||||
// Set OIDCConfigEndpoint to the discovery endpoint
|
||||
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
|
||||
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
|
||||
}
|
||||
|
||||
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
|
||||
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
|
||||
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
|
||||
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
|
||||
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
|
||||
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
|
||||
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
|
||||
oidcConfig.Issuer, cfg.HttpConfig.AuthIssuer)
|
||||
cfg.HttpConfig.AuthIssuer = oidcConfig.Issuer
|
||||
|
||||
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
|
||||
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
|
||||
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
|
||||
|
||||
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
|
||||
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
|
||||
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, cfg.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
cfg.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.DeviceAuthEndpoint, cfg.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
|
||||
cfg.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
|
||||
|
||||
u, err := url.Parse(oidcEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
|
||||
u.Host, cfg.DeviceAuthorizationFlow.ProviderConfig.Domain)
|
||||
cfg.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
|
||||
|
||||
if cfg.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
||||
cfg.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
|
||||
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
|
||||
if cfg.PKCEAuthorizationFlow == nil {
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.TokenEndpoint, cfg.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
|
||||
cfg.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
|
||||
|
||||
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
|
||||
oidcConfig.AuthorizationEndpoint, cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
|
||||
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
|
||||
}
|
||||
|
||||
// logConfigInfo logs informational messages about the loaded configuration
|
||||
func logConfigInfo(cfg *nbconfig.Config) {
|
||||
if cfg.EmbeddedIdP != nil {
|
||||
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
|
||||
}
|
||||
if cfg.Relay != nil {
|
||||
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
|
||||
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
|
||||
if cfg.DataStoreEncryptionKey != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
|
||||
key, err := crypt.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
|
||||
}
|
||||
cfg.DataStoreEncryptionKey = key
|
||||
|
||||
if err := util.DirectWriteJson(ctx, configPath, cfg); err != nil {
|
||||
return fmt.Errorf("failed to save config with new encryption key: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Infof("DataStoreEncryptionKey generated and saved to config")
|
||||
return nil
|
||||
}
|
||||
|
||||
// OIDCConfigResponse used for parsing OIDC config response
|
||||
|
||||
@@ -309,7 +309,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
|
||||
setupKeys := map[string]*types.SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
|
||||
owner := types.NewOwnerUser(userID)
|
||||
owner := types.NewOwnerUser(userID, "", "")
|
||||
owner.AccountID = accountID
|
||||
users[userID] = owner
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -29,6 +28,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -62,6 +62,14 @@ func (s *BaseServer) Store() store.Store {
|
||||
log.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
|
||||
if s.Config.DataStoreEncryptionKey != "" {
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(s.Config.DataStoreEncryptionKey)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create field encryptor: %v", err)
|
||||
}
|
||||
store.SetFieldEncrypt(fieldEncrypt)
|
||||
}
|
||||
|
||||
return store
|
||||
})
|
||||
}
|
||||
@@ -73,27 +81,18 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
}
|
||||
|
||||
eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
|
||||
if s.Config.DataStoreEncryptionKey != key {
|
||||
log.WithContext(context.Background()).Infof("update Config with activity store key")
|
||||
s.Config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to update Config with activity store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return eventStore
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -145,7 +144,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -57,6 +57,10 @@ type Config struct {
|
||||
|
||||
// disable default all-to-all policy
|
||||
DisableDefaultPolicy bool
|
||||
|
||||
// EmbeddedIdP contains configuration for the embedded Dex OIDC provider.
|
||||
// When set, Dex will be embedded in the management server and serve requests at /oauth2/
|
||||
EmbeddedIdP *idp.EmbeddedIdPConfig
|
||||
}
|
||||
|
||||
// GetAuthAudiences returns the audience from the http config and device authorization flow config
|
||||
|
||||
@@ -44,6 +44,9 @@ func maybeCreateNamed[T any](s Server, name string, createFunc func() T) (result
|
||||
|
||||
func maybeCreateKeyed[T any](s Server, key string, createFunc func() T) (result T, isNew bool) {
|
||||
if t, ok := s.GetContainer(key); ok {
|
||||
if t == nil {
|
||||
return result, false
|
||||
}
|
||||
return t.(T), false
|
||||
}
|
||||
|
||||
|
||||
@@ -55,14 +55,33 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthManager() auth.Manager {
|
||||
audiences := s.Config.GetAuthAudiences()
|
||||
audience := s.Config.HttpConfig.AuthAudience
|
||||
keysLocation := s.Config.HttpConfig.AuthKeysLocation
|
||||
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
|
||||
issuer := s.Config.HttpConfig.AuthIssuer
|
||||
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
|
||||
|
||||
// Use embedded IdP configuration if available
|
||||
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
|
||||
audiences = oauthProvider.GetClientIDs()
|
||||
if len(audiences) > 0 {
|
||||
audience = audiences[0] // Use the first client ID as the primary audience
|
||||
}
|
||||
keysLocation = oauthProvider.GetKeysLocation()
|
||||
signingKeyRefreshEnabled = true
|
||||
issuer = oauthProvider.GetIssuer()
|
||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||
}
|
||||
|
||||
return Create(s, func() auth.Manager {
|
||||
return auth.NewManager(s.Store(),
|
||||
s.Config.HttpConfig.AuthIssuer,
|
||||
s.Config.HttpConfig.AuthAudience,
|
||||
s.Config.HttpConfig.AuthKeysLocation,
|
||||
s.Config.HttpConfig.AuthUserIDClaim,
|
||||
s.Config.GetAuthAudiences(),
|
||||
s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||
issuer,
|
||||
audience,
|
||||
keysLocation,
|
||||
userIDClaim,
|
||||
audiences,
|
||||
signingKeyRefreshEnabled)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -95,6 +95,17 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
// Use embedded IdP manager if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP manager: %v", err)
|
||||
}
|
||||
return idpManager
|
||||
}
|
||||
|
||||
// Fall back to external IdP manager
|
||||
if s.Config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
@@ -105,6 +116,25 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
|
||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
idpManager := s.IdpManager()
|
||||
if idpManager == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reuse the EmbeddedIdPManager instance from IdpManager
|
||||
// EmbeddedIdPManager implements both idp.Manager and idp.OAuthConfigProvider
|
||||
if provider, ok := idpManager.(idp.OAuthConfigProvider); ok {
|
||||
return provider
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
return Create(s, func() groups.Manager {
|
||||
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
@@ -22,7 +23,6 @@ import (
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -40,7 +40,7 @@ type Server interface {
|
||||
SetContainer(key string, container any)
|
||||
}
|
||||
|
||||
// Server holds the HTTP BaseServer instance.
|
||||
// BaseServer holds the HTTP server instance.
|
||||
// Add any additional fields you need, such as database connections, Config, etc.
|
||||
type BaseServer struct {
|
||||
// Config holds the server configuration
|
||||
@@ -144,7 +144,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
@@ -215,6 +215,10 @@ func (s *BaseServer) Stop() error {
|
||||
if s.update != nil {
|
||||
s.update.StopWatch()
|
||||
}
|
||||
// Stop embedded IdP if configured
|
||||
if embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager); ok {
|
||||
_ = embeddedIdP.Stop(ctx)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.Errors():
|
||||
@@ -246,11 +250,7 @@ func (s *BaseServer) SetContainer(key string, container any) {
|
||||
log.Tracef("container with key %s set successfully", key)
|
||||
}
|
||||
|
||||
func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error {
|
||||
return util.DirectWriteJson(ctx, path, config)
|
||||
}
|
||||
|
||||
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -69,6 +71,8 @@ type Server struct {
|
||||
|
||||
networkMapController network_map.Controller
|
||||
|
||||
oAuthConfigProvider idp.OAuthConfigProvider
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
@@ -83,6 +87,7 @@ func NewServer(
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
networkMapController network_map.Controller,
|
||||
oAuthConfigProvider idp.OAuthConfigProvider,
|
||||
) (*Server, error) {
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
@@ -119,6 +124,7 @@ func NewServer(
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
networkMapController: networkMapController,
|
||||
oAuthConfigProvider: oAuthConfigProvider,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
@@ -761,32 +767,48 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
|
||||
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
|
||||
}
|
||||
var flowInfoResp *proto.DeviceAuthorizationFlow
|
||||
|
||||
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
|
||||
}
|
||||
// Use embedded IdP configuration if available
|
||||
if s.oAuthConfigProvider != nil {
|
||||
flowInfoResp = &proto.DeviceAuthorizationFlow{
|
||||
Provider: proto.DeviceAuthorizationFlow_HOSTED,
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
ClientID: s.oAuthConfigProvider.GetCLIClientID(),
|
||||
Audience: s.oAuthConfigProvider.GetCLIClientID(),
|
||||
DeviceAuthEndpoint: s.oAuthConfigProvider.GetDeviceAuthEndpoint(),
|
||||
TokenEndpoint: s.oAuthConfigProvider.GetTokenEndpoint(),
|
||||
Scope: s.oAuthConfigProvider.GetDefaultScopes(),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
|
||||
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.DeviceAuthorizationFlow{
|
||||
Provider: proto.DeviceAuthorizationFlowProvider(provider),
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
|
||||
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
|
||||
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
|
||||
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
|
||||
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
},
|
||||
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
|
||||
}
|
||||
|
||||
flowInfoResp = &proto.DeviceAuthorizationFlow{
|
||||
Provider: proto.DeviceAuthorizationFlowProvider(provider),
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
|
||||
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
|
||||
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
|
||||
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
|
||||
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt device authorization flow information")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
@@ -820,30 +842,47 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
if s.config.PKCEAuthorizationFlow == nil {
|
||||
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
||||
}
|
||||
var initInfoFlow *proto.PKCEAuthorizationFlow
|
||||
|
||||
initInfoFlow := &proto.PKCEAuthorizationFlow{
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
||||
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
|
||||
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
|
||||
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
|
||||
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
|
||||
LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag),
|
||||
},
|
||||
// Use embedded IdP configuration if available
|
||||
if s.oAuthConfigProvider != nil {
|
||||
initInfoFlow = &proto.PKCEAuthorizationFlow{
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
Audience: s.oAuthConfigProvider.GetCLIClientID(),
|
||||
ClientID: s.oAuthConfigProvider.GetCLIClientID(),
|
||||
TokenEndpoint: s.oAuthConfigProvider.GetTokenEndpoint(),
|
||||
AuthorizationEndpoint: s.oAuthConfigProvider.GetAuthorizationEndpoint(),
|
||||
Scope: s.oAuthConfigProvider.GetDefaultScopes(),
|
||||
RedirectURLs: s.oAuthConfigProvider.GetCLIRedirectURLs(),
|
||||
LoginFlag: uint32(common.LoginFlagPromptLogin),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
if s.config.PKCEAuthorizationFlow == nil {
|
||||
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
|
||||
}
|
||||
|
||||
initInfoFlow = &proto.PKCEAuthorizationFlow{
|
||||
ProviderConfig: &proto.ProviderConfig{
|
||||
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
|
||||
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
|
||||
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
|
||||
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
|
||||
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
|
||||
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
|
||||
DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
|
||||
LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt pkce authorization flow information")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
|
||||
@@ -243,7 +243,7 @@ func BuildManager(
|
||||
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
|
||||
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
go func() {
|
||||
err := am.warmupIDPCache(ctx, cacheStore)
|
||||
if err != nil {
|
||||
@@ -557,7 +557,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co
|
||||
|
||||
// newAccount creates a new Account with a generated ID and generated default setup keys.
|
||||
// If ID is already in use (due to collision) we try one more time before returning error
|
||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) {
|
||||
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain, email, name string) (*types.Account, error) {
|
||||
for i := 0; i < 2; i++ {
|
||||
accountId := xid.New().String()
|
||||
|
||||
@@ -568,7 +568,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
|
||||
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
|
||||
continue
|
||||
case statusErr.Type() == status.NotFound:
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
|
||||
newAccount := newAccountWithId(ctx, accountId, userID, domain, email, name, am.disableDefaultPolicy)
|
||||
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
|
||||
return newAccount, nil
|
||||
default:
|
||||
@@ -741,23 +741,23 @@ func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID st
|
||||
// If user does have an account, it returns the user's account ID.
|
||||
// If the user doesn't have an account, it creates one using the provided domain.
|
||||
// Returns the account ID or an error if none is found or created.
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
|
||||
if userID == "" {
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", status.Errorf(status.NotFound, "no valid userID provided")
|
||||
}
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
acc, err := am.GetOrCreateAccountByUser(ctx, userAuth)
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userAuth.UserId)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, acc.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
return acc.Id, nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
@@ -768,9 +768,19 @@ func isNil(i idp.Manager) bool {
|
||||
return i == nil || reflect.ValueOf(i).IsNil()
|
||||
}
|
||||
|
||||
// IsEmbeddedIdp checks if the IDP manager is an embedded IDP (data stored locally in DB).
|
||||
// When true, user cache should be skipped and data fetched directly from the IDP manager.
|
||||
func IsEmbeddedIdp(i idp.Manager) bool {
|
||||
if isNil(i) {
|
||||
return false
|
||||
}
|
||||
_, ok := i.(*idp.EmbeddedIdPManager)
|
||||
return ok
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
@@ -1016,6 +1026,9 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accountID, userID string) error {
|
||||
if IsEmbeddedIdp(am.idpManager) {
|
||||
return nil
|
||||
}
|
||||
data, err := am.getAccountFromCache(ctx, accountID, false)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1107,7 +1120,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
|
||||
lowerDomain := strings.ToLower(userAuth.Domain)
|
||||
|
||||
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain)
|
||||
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain, userAuth.Email, userAuth.Name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1132,7 +1145,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser := types.NewRegularUser(userAuth.UserId, userAuth.Email, userAuth.Name)
|
||||
newUser.AccountID = domainAccountID
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID)
|
||||
@@ -1315,6 +1328,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
log.Errorf("failed to get user by ID %s: %v", userAuth.UserId, err)
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
|
||||
}
|
||||
|
||||
@@ -1512,7 +1526,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
}
|
||||
|
||||
if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
|
||||
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
|
||||
return am.GetAccountIDByUserID(ctx, userAuth)
|
||||
}
|
||||
|
||||
if userAuth.AccountId != "" {
|
||||
@@ -1734,7 +1748,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain, email, name string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
@@ -1744,7 +1758,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
|
||||
setupKeys := map[string]*types.SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
|
||||
owner := types.NewOwnerUser(userID)
|
||||
owner := types.NewOwnerUser(userID, email, name)
|
||||
owner.AccountID = accountID
|
||||
users[userID] = owner
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ import (
|
||||
type ExternalCacheManager nbcache.UserDataCache
|
||||
|
||||
type Manager interface {
|
||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error)
|
||||
GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error)
|
||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
@@ -44,7 +44,7 @@ type Manager interface {
|
||||
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
|
||||
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
@@ -124,4 +124,9 @@ type Manager interface {
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
|
||||
GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error)
|
||||
CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error
|
||||
}
|
||||
|
||||
@@ -382,7 +382,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, testCase := range tt {
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
|
||||
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", "", "", false)
|
||||
account.UpdateSettings(&testCase.accountSettings)
|
||||
account.Network = network
|
||||
account.Peers = testCase.peers
|
||||
@@ -407,7 +407,7 @@ func TestNewAccount(t *testing.T) {
|
||||
domain := "netbird.io"
|
||||
userId := "account_creator"
|
||||
accountID := "account_id"
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userId, domain, "", "", false)
|
||||
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
|
||||
}
|
||||
|
||||
@@ -418,7 +418,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID, Domain: ""})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -612,7 +612,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain})
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
@@ -649,10 +649,10 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
|
||||
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, false)
|
||||
_ = newAccountWithId(context.Background(), "", userId, domain, "", "", false)
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userId, Domain: domain})
|
||||
require.NoError(t, err, "create init user failed")
|
||||
// as initAccount was created without account id we have to take the id after account initialization
|
||||
// that happens inside the GetAccountIDByUserID where the id is getting generated
|
||||
@@ -718,7 +718,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
}
|
||||
|
||||
userId := "test_user"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: ""})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -745,7 +745,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
domain := "hotmail.com"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: domain})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -759,7 +759,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
|
||||
|
||||
domain = "gmail.com"
|
||||
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
|
||||
account, err = manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: domain})
|
||||
if err != nil {
|
||||
t.Fatalf("got the following error while retrieving existing acc: %v", err)
|
||||
}
|
||||
@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
|
||||
userId := "test_user"
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userId, Domain: ""})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -795,14 +795,14 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists, "expected to get existing account after creation using userid")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), "", "")
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: "", Domain: ""})
|
||||
if err == nil {
|
||||
t.Errorf("expected an error when user ID is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1098,7 +1098,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID, Domain: "netbird.cloud"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -1849,7 +1849,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
|
||||
@@ -1864,7 +1864,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1876,7 +1876,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
|
||||
@@ -1920,7 +1920,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1946,7 +1946,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
@@ -1963,7 +1963,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
_, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
@@ -1975,7 +1975,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
@@ -2025,7 +2025,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -3434,7 +3434,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
|
||||
assert.True(t, cold)
|
||||
})
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should return true when account is not found in cache", func(t *testing.T) {
|
||||
@@ -3462,7 +3462,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
|
||||
initiatorId := "test-user"
|
||||
domain := "example.com"
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
|
||||
account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain})
|
||||
require.NoError(t, err)
|
||||
|
||||
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
|
||||
@@ -3575,7 +3575,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
|
||||
@@ -3607,7 +3607,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
onboarding := &types.AccountOnboarding{
|
||||
@@ -3646,7 +3646,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key1, err := wgtypes.GenerateKey()
|
||||
@@ -3717,7 +3717,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
|
||||
|
||||
// Create a domain-based account with user approval enabled
|
||||
existingAccountID := "existing-account"
|
||||
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false)
|
||||
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", "", "", false)
|
||||
account.Settings.Extra = &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
}
|
||||
|
||||
@@ -183,6 +183,10 @@ const (
|
||||
|
||||
AccountAutoUpdateVersionUpdated Activity = 92
|
||||
|
||||
IdentityProviderCreated Activity = 93
|
||||
IdentityProviderUpdated Activity = 94
|
||||
IdentityProviderDeleted Activity = 95
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -295,6 +299,10 @@ var activityMap = map[Activity]Code{
|
||||
UserCreated: {"User created", "user.create"},
|
||||
|
||||
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
|
||||
|
||||
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
|
||||
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
|
||||
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -49,8 +49,7 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s
|
||||
)
|
||||
|
||||
return &manager{
|
||||
store: store,
|
||||
|
||||
store: store,
|
||||
validator: jwtValidator,
|
||||
extractor: claimsExtractor,
|
||||
}
|
||||
|
||||
@@ -277,7 +277,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
|
||||
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, "", "", false)
|
||||
|
||||
account.Users[dnsRegularUserID] = &types.User{
|
||||
Id: dnsRegularUserID,
|
||||
|
||||
@@ -379,7 +379,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -29,6 +30,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/events"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/idp"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/instance"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
|
||||
@@ -36,6 +39,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
@@ -51,23 +56,15 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(
|
||||
ctx context.Context,
|
||||
accountManager account.Manager,
|
||||
networksManager nbnetworks.Manager,
|
||||
resourceManager resources.Manager,
|
||||
routerManager routers.Manager,
|
||||
groupsManager nbgroups.Manager,
|
||||
LocationManager geolocation.Geolocation,
|
||||
authManager auth.Manager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
integratedValidator integrated_validator.IntegratedValidator,
|
||||
proxyController port_forwarding.Controller,
|
||||
permissionsManager permissions.Manager,
|
||||
peersManager nbpeers.Manager,
|
||||
settingsManager settings.Manager,
|
||||
networkMapController network_map.Controller,
|
||||
) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
if err := bypass.AddBypassPath("/api/setup"); err != nil {
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
@@ -122,7 +119,14 @@ func NewAPIHandler(
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
// Check if embedded IdP is enabled
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
@@ -134,6 +138,13 @@ func NewAPIHandler(
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
||||
}
|
||||
|
||||
return rootRouter, nil
|
||||
}
|
||||
|
||||
@@ -36,22 +36,24 @@ const (
|
||||
|
||||
// handler is a handler that handles the server.Account HTTP endpoints
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
embeddedIdpEnabled bool
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager)
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
|
||||
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
embeddedIdpEnabled: embeddedIdpEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,7 +165,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -290,7 +292,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
|
||||
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, &resp)
|
||||
}
|
||||
@@ -319,7 +321,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
|
||||
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
|
||||
jwtAllowGroups := settings.JWTAllowGroups
|
||||
if jwtAllowGroups == nil {
|
||||
jwtAllowGroups = []string{}
|
||||
@@ -339,6 +341,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
EmbeddedIdpEnabled: &embeddedIdpEnabled,
|
||||
}
|
||||
|
||||
if settings.NetworkRange.IsValid() {
|
||||
|
||||
@@ -33,6 +33,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
AnyTimes()
|
||||
|
||||
return &handler{
|
||||
embeddedIdpEnabled: false,
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
|
||||
return account.Settings, nil
|
||||
@@ -122,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -145,6 +147,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -168,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -191,6 +195,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -214,6 +219,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
@@ -237,6 +243,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateVersion: sr(""),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
196
management/server/http/handlers/idp/idp_handler.go
Normal file
196
management/server/http/handlers/idp/idp_handler.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// handler handles identity provider HTTP endpoints
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
// AddEndpoints registers identity provider endpoints
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
h := newHandler(accountManager)
|
||||
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
}
|
||||
}
|
||||
|
||||
// getAllIdentityProviders returns all identity providers for the account
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]api.IdentityProvider, 0, len(providers))
|
||||
for _, p := range providers {
|
||||
response = append(response, toAPIResponse(p))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, response)
|
||||
}
|
||||
|
||||
// getIdentityProvider returns a specific identity provider
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAPIResponse(provider))
|
||||
}
|
||||
|
||||
// createIdentityProvider creates a new identity provider
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
var req api.IdentityProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAPIResponse(created))
|
||||
}
|
||||
|
||||
// updateIdentityProvider updates an existing identity provider
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.IdentityProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toAPIResponse(updated))
|
||||
}
|
||||
|
||||
// deleteIdentityProvider deletes an identity provider
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toAPIResponse(idp *types.IdentityProvider) api.IdentityProvider {
|
||||
resp := api.IdentityProvider{
|
||||
Type: api.IdentityProviderType(idp.Type),
|
||||
Name: idp.Name,
|
||||
Issuer: idp.Issuer,
|
||||
ClientId: idp.ClientID,
|
||||
}
|
||||
if idp.ID != "" {
|
||||
resp.Id = &idp.ID
|
||||
}
|
||||
// Note: ClientSecret is never returned in responses for security
|
||||
return resp
|
||||
}
|
||||
|
||||
func fromAPIRequest(req *api.IdentityProviderRequest) *types.IdentityProvider {
|
||||
return &types.IdentityProvider{
|
||||
Type: types.IdentityProviderType(req.Type),
|
||||
Name: req.Name,
|
||||
Issuer: req.Issuer,
|
||||
ClientID: req.ClientId,
|
||||
ClientSecret: req.ClientSecret,
|
||||
}
|
||||
}
|
||||
438
management/server/http/handlers/idp/idp_handler_test.go
Normal file
438
management/server/http/handlers/idp/idp_handler_test.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
testAccountID = "test-account-id"
|
||||
testUserID = "test-user-id"
|
||||
existingIDPID = "existing-idp-id"
|
||||
newIDPID = "new-idp-id"
|
||||
)
|
||||
|
||||
func initIDPTestData(existingIDP *types.IdentityProvider) *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetIdentityProvidersFunc: func(_ context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
if accountID != testAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if existingIDP != nil {
|
||||
return []*types.IdentityProvider{existingIDP}, nil
|
||||
}
|
||||
return []*types.IdentityProvider{}, nil
|
||||
},
|
||||
GetIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
|
||||
if accountID != testAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if existingIDP != nil && idpID == existingIDP.ID {
|
||||
return existingIDP, nil
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "identity provider not found")
|
||||
},
|
||||
CreateIdentityProviderFunc: func(_ context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
if accountID != testAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if idp.Name == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
created := idp.Copy()
|
||||
created.ID = newIDPID
|
||||
created.AccountID = accountID
|
||||
return created, nil
|
||||
},
|
||||
UpdateIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
if accountID != testAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if existingIDP == nil || idpID != existingIDP.ID {
|
||||
return nil, status.Errorf(status.NotFound, "identity provider not found")
|
||||
}
|
||||
updated := idp.Copy()
|
||||
updated.ID = idpID
|
||||
updated.AccountID = accountID
|
||||
return updated, nil
|
||||
},
|
||||
DeleteIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) error {
|
||||
if accountID != testAccountID {
|
||||
return status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
if existingIDP == nil || idpID != existingIDP.ID {
|
||||
return status.Errorf(status.NotFound, "identity provider not found")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllIdentityProviders(t *testing.T) {
|
||||
existingIDP := &types.IdentityProvider{
|
||||
ID: existingIDPID,
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
Issuer: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "Get All Identity Providers",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
h := initIDPTestData(existingIDP)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/identity-providers", nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, recorder.Code)
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var idps []api.IdentityProvider
|
||||
err = json.Unmarshal(content, &idps)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, idps, tc.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetIdentityProvider(t *testing.T) {
|
||||
existingIDP := &types.IdentityProvider{
|
||||
ID: existingIDPID,
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
Issuer: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
idpID string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
}{
|
||||
{
|
||||
name: "Get Existing Identity Provider",
|
||||
idpID: existingIDPID,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Non-Existing Identity Provider",
|
||||
idpID: "non-existing-id",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
}
|
||||
|
||||
h := initIDPTestData(existingIDP)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, recorder.Code)
|
||||
|
||||
if tc.expectedBody {
|
||||
content, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var idp api.IdentityProvider
|
||||
err = json.Unmarshal(content, &idp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, existingIDPID, *idp.Id)
|
||||
assert.Equal(t, existingIDP.Name, idp.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateIdentityProvider(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
}{
|
||||
{
|
||||
name: "Create Identity Provider",
|
||||
requestBody: `{
|
||||
"name": "New IDP",
|
||||
"type": "oidc",
|
||||
"issuer": "https://new-issuer.example.com",
|
||||
"client_id": "new-client-id",
|
||||
"client_secret": "new-client-secret"
|
||||
}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Create Identity Provider with Invalid JSON",
|
||||
requestBody: `{invalid json`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: false,
|
||||
},
|
||||
}
|
||||
|
||||
h := initIDPTestData(nil)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/identity-providers", bytes.NewBufferString(tc.requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, recorder.Code)
|
||||
|
||||
if tc.expectedBody {
|
||||
content, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var idp api.IdentityProvider
|
||||
err = json.Unmarshal(content, &idp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, newIDPID, *idp.Id)
|
||||
assert.Equal(t, "New IDP", idp.Name)
|
||||
assert.Equal(t, api.IdentityProviderTypeOidc, idp.Type)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateIdentityProvider(t *testing.T) {
|
||||
existingIDP := &types.IdentityProvider{
|
||||
ID: existingIDPID,
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
Issuer: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
idpID string
|
||||
requestBody string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
}{
|
||||
{
|
||||
name: "Update Existing Identity Provider",
|
||||
idpID: existingIDPID,
|
||||
requestBody: `{
|
||||
"name": "Updated IDP",
|
||||
"type": "oidc",
|
||||
"issuer": "https://updated-issuer.example.com",
|
||||
"client_id": "updated-client-id"
|
||||
}`,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Update Non-Existing Identity Provider",
|
||||
idpID: "non-existing-id",
|
||||
requestBody: `{
|
||||
"name": "Updated IDP",
|
||||
"type": "oidc",
|
||||
"issuer": "https://updated-issuer.example.com",
|
||||
"client_id": "updated-client-id"
|
||||
}`,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "Update Identity Provider with Invalid JSON",
|
||||
idpID: existingIDPID,
|
||||
requestBody: `{invalid json`,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: false,
|
||||
},
|
||||
}
|
||||
|
||||
h := initIDPTestData(existingIDP)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), bytes.NewBufferString(tc.requestBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, recorder.Code)
|
||||
|
||||
if tc.expectedBody {
|
||||
content, err := io.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
var idp api.IdentityProvider
|
||||
err = json.Unmarshal(content, &idp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, existingIDPID, *idp.Id)
|
||||
assert.Equal(t, "Updated IDP", idp.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteIdentityProvider(t *testing.T) {
|
||||
existingIDP := &types.IdentityProvider{
|
||||
ID: existingIDPID,
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
Issuer: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
idpID string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Delete Existing Identity Provider",
|
||||
idpID: existingIDPID,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Non-Existing Identity Provider",
|
||||
idpID: "non-existing-id",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
h := initIDPTestData(existingIDP)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, recorder.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToAPIResponse(t *testing.T) {
|
||||
idp := &types.IdentityProvider{
|
||||
ID: "test-id",
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeGoogle,
|
||||
Issuer: "https://accounts.google.com",
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "should-not-be-returned",
|
||||
}
|
||||
|
||||
response := toAPIResponse(idp)
|
||||
|
||||
assert.Equal(t, "test-id", *response.Id)
|
||||
assert.Equal(t, "Test IDP", response.Name)
|
||||
assert.Equal(t, api.IdentityProviderTypeGoogle, response.Type)
|
||||
assert.Equal(t, "https://accounts.google.com", response.Issuer)
|
||||
assert.Equal(t, "client-id", response.ClientId)
|
||||
// Note: ClientSecret is not included in response type by design
|
||||
}
|
||||
|
||||
func TestFromAPIRequest(t *testing.T) {
|
||||
req := &api.IdentityProviderRequest{
|
||||
Name: "New IDP",
|
||||
Type: api.IdentityProviderTypeOkta,
|
||||
Issuer: "https://dev-123456.okta.com",
|
||||
ClientId: "okta-client-id",
|
||||
ClientSecret: "okta-client-secret",
|
||||
}
|
||||
|
||||
idp := fromAPIRequest(req)
|
||||
|
||||
assert.Equal(t, "New IDP", idp.Name)
|
||||
assert.Equal(t, types.IdentityProviderTypeOkta, idp.Type)
|
||||
assert.Equal(t, "https://dev-123456.okta.com", idp.Issuer)
|
||||
assert.Equal(t, "okta-client-id", idp.ClientID)
|
||||
assert.Equal(t, "okta-client-secret", idp.ClientSecret)
|
||||
}
|
||||
67
management/server/http/handlers/instance/instance_handler.go
Normal file
67
management/server/http/handlers/instance/instance_handler.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// handler handles the instance setup HTTP endpoints
|
||||
type handler struct {
|
||||
instanceManager nbinstance.Manager
|
||||
}
|
||||
|
||||
// AddEndpoints registers the instance setup endpoints.
|
||||
// These endpoints bypass authentication for initial setup.
|
||||
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
// This endpoint is unauthenticated.
|
||||
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
|
||||
setupRequired, err := h.instanceManager.IsSetupRequired(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to check setup status: %v", err)
|
||||
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
|
||||
SetupRequired: setupRequired,
|
||||
})
|
||||
}
|
||||
|
||||
// setup creates the initial admin user for the instance.
|
||||
// This endpoint is unauthenticated but only works when setup is required.
|
||||
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
var req api.SetupRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
|
||||
UserId: userData.ID,
|
||||
Email: userData.Email,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/mail"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// mockInstanceManager implements instance.Manager for testing
|
||||
type mockInstanceManager struct {
|
||||
isSetupRequired bool
|
||||
isSetupRequiredFn func(ctx context.Context) (bool, error)
|
||||
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
if m.isSetupRequiredFn != nil {
|
||||
return m.isSetupRequiredFn(ctx)
|
||||
}
|
||||
return m.isSetupRequired, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
if m.createOwnerUserFn != nil {
|
||||
return m.createOwnerUserFn(ctx, email, password, name)
|
||||
}
|
||||
|
||||
// Default mock includes validation like the real manager
|
||||
if !m.isSetupRequired {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
|
||||
}
|
||||
if email == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "email is required")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid email format")
|
||||
}
|
||||
if name == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if password == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "password is required")
|
||||
}
|
||||
if len(password) < 8 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
|
||||
}
|
||||
|
||||
return &idp.UserData{
|
||||
ID: "test-user-id",
|
||||
Email: email,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
|
||||
|
||||
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
|
||||
router := mux.NewRouter()
|
||||
AddEndpoints(manager, router)
|
||||
return router
|
||||
}
|
||||
|
||||
func TestGetInstanceStatus_SetupRequired(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response api.InstanceStatus
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, response.SetupRequired)
|
||||
}
|
||||
|
||||
func TestGetInstanceStatus_SetupNotRequired(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: false}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response api.InstanceStatus
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, response.SetupRequired)
|
||||
}
|
||||
|
||||
func TestGetInstanceStatus_Error(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequiredFn: func(ctx context.Context) (bool, error) {
|
||||
return false, errors.New("database error")
|
||||
},
|
||||
}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
assert.Equal(t, "admin@example.com", email)
|
||||
assert.Equal(t, "securepassword123", password)
|
||||
assert.Equal(t, "Admin User", name)
|
||||
return &idp.UserData{
|
||||
ID: "created-user-id",
|
||||
Email: email,
|
||||
Name: name,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin User"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var response api.SetupResponse
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "created-user-id", response.UserId)
|
||||
assert.Equal(t, "admin@example.com", response.Email)
|
||||
}
|
||||
|
||||
func TestSetup_AlreadyCompleted(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: false}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusPreconditionFailed, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_MissingEmail(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"password": "securepassword123"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_InvalidEmail(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "not-an-email", "password": "securepassword123", "name": "User"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
// Note: Invalid email format uses mail.ParseAddress which is treated differently
|
||||
// and returns 400 Bad Request instead of 422 Unprocessable Entity
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_MissingPassword(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "admin@example.com", "name": "User"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_PasswordTooShort(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "short", "name": "User"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_InvalidJSON(t *testing.T) {
|
||||
manager := &mockInstanceManager{isSetupRequired: true}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{invalid json}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_CreateUserError(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return nil, errors.New("user creation failed")
|
||||
},
|
||||
}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
|
||||
func TestSetup_ManagerError(t *testing.T) {
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return nil, status.Errorf(status.Internal, "database error")
|
||||
},
|
||||
}
|
||||
router := setupTestRouter(manager)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
},
|
||||
}
|
||||
|
||||
srvUser := types.NewRegularUser(serviceUser)
|
||||
srvUser := types.NewRegularUser(serviceUser, "", "")
|
||||
srvUser.IsServiceUser = true
|
||||
|
||||
account := &types.Account{
|
||||
@@ -75,7 +75,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
|
||||
Peers: peersMap,
|
||||
Users: map[string]*types.User{
|
||||
adminUser: types.NewAdminUser(adminUser),
|
||||
regularUser: types.NewRegularUser(regularUser),
|
||||
regularUser: types.NewRegularUser(regularUser, "", ""),
|
||||
serviceUser: srvUser,
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
|
||||
@@ -326,6 +326,16 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
|
||||
var password *string
|
||||
if user.Password != "" {
|
||||
password = &user.Password
|
||||
}
|
||||
|
||||
var idpID *string
|
||||
if user.IdPID != "" {
|
||||
idpID = &user.IdPID
|
||||
}
|
||||
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
@@ -339,6 +349,8 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
PendingApproval: user.PendingApproval,
|
||||
Password: password,
|
||||
IdpId: idpID,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -134,6 +134,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
userAuth.IsChild = ok
|
||||
}
|
||||
|
||||
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
|
||||
// Available as userAuth.Email
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
|
||||
@@ -94,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
234
management/server/identity_provider.go
Normal file
234
management/server/identity_provider.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// GetIdentityProviders returns all identity providers for an account
|
||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
log.Warn("identity provider management requires embedded IdP")
|
||||
return []*types.IdentityProvider{}, nil
|
||||
}
|
||||
|
||||
connectors, err := embeddedManager.ListConnectors(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to list identity providers: %v", err)
|
||||
}
|
||||
|
||||
result := make([]*types.IdentityProvider, 0, len(connectors))
|
||||
for _, conn := range connectors {
|
||||
result = append(result, connectorConfigToIdentityProvider(conn, accountID))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetIdentityProvider returns a specific identity provider by ID
|
||||
func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP")
|
||||
}
|
||||
|
||||
conn, err := embeddedManager.GetConnector(ctx, idpID)
|
||||
if err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "identity provider not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get identity provider: %v", err)
|
||||
}
|
||||
|
||||
return connectorConfigToIdentityProvider(conn, accountID), nil
|
||||
}
|
||||
|
||||
// CreateIdentityProvider creates a new identity provider
|
||||
func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP")
|
||||
}
|
||||
|
||||
// Generate ID if not provided
|
||||
if idpConfig.ID == "" {
|
||||
idpConfig.ID = generateIdentityProviderID(idpConfig.Type)
|
||||
}
|
||||
idpConfig.AccountID = accountID
|
||||
|
||||
connCfg := identityProviderToConnectorConfig(idpConfig)
|
||||
|
||||
_, err = embeddedManager.CreateConnector(ctx, connCfg)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to create identity provider: %v", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, idpConfig.ID, accountID, activity.IdentityProviderCreated, idpConfig.EventMeta())
|
||||
|
||||
return idpConfig, nil
|
||||
}
|
||||
|
||||
// UpdateIdentityProvider updates an existing identity provider
|
||||
func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := idpConfig.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP")
|
||||
}
|
||||
|
||||
idpConfig.ID = idpID
|
||||
idpConfig.AccountID = accountID
|
||||
|
||||
connCfg := identityProviderToConnectorConfig(idpConfig)
|
||||
|
||||
if err := embeddedManager.UpdateConnector(ctx, connCfg); err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to update identity provider: %v", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, idpConfig.ID, accountID, activity.IdentityProviderUpdated, idpConfig.EventMeta())
|
||||
|
||||
return idpConfig, nil
|
||||
}
|
||||
|
||||
// DeleteIdentityProvider deletes an identity provider
|
||||
func (am *DefaultAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error {
|
||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||
if !ok {
|
||||
return status.Errorf(status.Internal, "identity provider management requires embedded IdP")
|
||||
}
|
||||
|
||||
// Get the IDP info before deleting for the activity event
|
||||
conn, err := embeddedManager.GetConnector(ctx, idpID)
|
||||
if err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return status.Errorf(status.NotFound, "identity provider not found")
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to get identity provider: %v", err)
|
||||
}
|
||||
idpConfig := connectorConfigToIdentityProvider(conn, accountID)
|
||||
|
||||
if err := embeddedManager.DeleteConnector(ctx, idpID); err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return status.Errorf(status.NotFound, "identity provider not found")
|
||||
}
|
||||
return status.Errorf(status.Internal, "failed to delete identity provider: %v", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, idpID, accountID, activity.IdentityProviderDeleted, idpConfig.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectorConfigToIdentityProvider converts a dex.ConnectorConfig to types.IdentityProvider
|
||||
func connectorConfigToIdentityProvider(conn *dex.ConnectorConfig, accountID string) *types.IdentityProvider {
|
||||
return &types.IdentityProvider{
|
||||
ID: conn.ID,
|
||||
AccountID: accountID,
|
||||
Type: types.IdentityProviderType(conn.Type),
|
||||
Name: conn.Name,
|
||||
Issuer: conn.Issuer,
|
||||
ClientID: conn.ClientID,
|
||||
ClientSecret: conn.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// identityProviderToConnectorConfig converts a types.IdentityProvider to dex.ConnectorConfig
|
||||
func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.ConnectorConfig {
|
||||
return &dex.ConnectorConfig{
|
||||
ID: idpConfig.ID,
|
||||
Name: idpConfig.Name,
|
||||
Type: string(idpConfig.Type),
|
||||
Issuer: idpConfig.Issuer,
|
||||
ClientID: idpConfig.ClientID,
|
||||
ClientSecret: idpConfig.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// generateIdentityProviderID generates a unique ID for an identity provider.
|
||||
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft),
|
||||
// the ID is prefixed with the type name. Generic OIDC providers get no prefix.
|
||||
func generateIdentityProviderID(idpType types.IdentityProviderType) string {
|
||||
id := xid.New().String()
|
||||
|
||||
switch idpType {
|
||||
case types.IdentityProviderTypeOkta:
|
||||
return "okta-" + id
|
||||
case types.IdentityProviderTypeZitadel:
|
||||
return "zitadel-" + id
|
||||
case types.IdentityProviderTypeEntra:
|
||||
return "entra-" + id
|
||||
case types.IdentityProviderTypeGoogle:
|
||||
return "google-" + id
|
||||
case types.IdentityProviderTypePocketID:
|
||||
return "pocketid-" + id
|
||||
case types.IdentityProviderTypeMicrosoft:
|
||||
return "microsoft-" + id
|
||||
case types.IdentityProviderTypeAuthentik:
|
||||
return "authentik-" + id
|
||||
case types.IdentityProviderTypeKeycloak:
|
||||
return "keycloak-" + id
|
||||
default:
|
||||
// Generic OIDC - no prefix
|
||||
return id
|
||||
}
|
||||
}
|
||||
202
management/server/identity_provider_test.go
Normal file
202
management/server/identity_provider_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dataDir := t.TempDir()
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "", dataDir)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
// Create embedded IdP manager
|
||||
embeddedConfig := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(dataDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
idpManager, err := idp.NewEmbeddedIdPManager(ctx, embeddedConfig, nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
t.Cleanup(func() { _ = idpManager.Stop(ctx) })
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(false, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, testStore)
|
||||
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peers.NewManager(testStore, permissionsManager)), &config.Config{})
|
||||
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateIdentityProvider_Validation(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
idp *types.IdentityProvider
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Missing Name",
|
||||
idp: &types.IdentityProvider{
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
Issuer: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "name is required",
|
||||
},
|
||||
{
|
||||
name: "Missing Type",
|
||||
idp: &types.IdentityProvider{
|
||||
Name: "Test IDP",
|
||||
Issuer: "https://issuer.example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "type is required",
|
||||
},
|
||||
{
|
||||
name: "Missing Issuer",
|
||||
idp: &types.IdentityProvider{
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "issuer is required",
|
||||
},
|
||||
{
|
||||
name: "Missing ClientID",
|
||||
idp: &types.IdentityProvider{
|
||||
Name: "Test IDP",
|
||||
Type: types.IdentityProviderTypeOIDC,
|
||||
Issuer: "https://issuer.example.com",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "client ID is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := manager.CreateIdentityProvider(context.Background(), account.Id, userID, tc.idp)
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetIdentityProviders(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return empty list (stub implementation)
|
||||
providers, err := manager.GetIdentityProviders(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, providers)
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetIdentityProvider_NotFound(t *testing.T) {
|
||||
manager, _, err := createManagerWithEmbeddedIdP(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should return not found error when identity provider doesn't exist
|
||||
_, err = manager.GetIdentityProvider(context.Background(), account.Id, "any-id", userID)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should fail validation before reaching "not implemented" error
|
||||
invalidIDP := &types.IdentityProvider{
|
||||
Name: "", // Empty name should fail validation
|
||||
}
|
||||
|
||||
_, err = manager.UpdateIdentityProvider(context.Background(), account.Id, "some-id", userID, invalidIDP)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "name is required")
|
||||
}
|
||||
511
management/server/idp/embedded.go
Normal file
511
management/server/idp/embedded.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
staticClientDashboard = "netbird-dashboard"
|
||||
staticClientCLI = "netbird-cli"
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email offline_access"
|
||||
defaultUserIDClaim = "sub"
|
||||
)
|
||||
|
||||
// EmbeddedIdPConfig contains configuration for the embedded Dex OIDC identity provider
|
||||
type EmbeddedIdPConfig struct {
|
||||
// Enabled indicates whether the embedded IDP is enabled
|
||||
Enabled bool
|
||||
// Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2")
|
||||
Issuer string
|
||||
// Storage configuration for the IdP database
|
||||
Storage EmbeddedStorageConfig
|
||||
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
|
||||
DashboardRedirectURIs []string
|
||||
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
|
||||
CLIRedirectURIs []string
|
||||
// Owner is the initial owner/admin user (optional, can be nil)
|
||||
Owner *OwnerConfig
|
||||
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
|
||||
SignKeyRefreshEnabled bool
|
||||
}
|
||||
|
||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||
type EmbeddedStorageConfig struct {
|
||||
// Type is the storage type (currently only "sqlite3" is supported)
|
||||
Type string
|
||||
// Config contains type-specific configuration
|
||||
Config EmbeddedStorageTypeConfig
|
||||
}
|
||||
|
||||
// EmbeddedStorageTypeConfig contains type-specific storage configuration.
|
||||
type EmbeddedStorageTypeConfig struct {
|
||||
// File is the path to the SQLite database file (for sqlite3 type)
|
||||
File string
|
||||
}
|
||||
|
||||
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
|
||||
type OwnerConfig struct {
|
||||
// Email is the user's email address (required)
|
||||
Email string
|
||||
// Hash is the bcrypt hash of the user's password (required)
|
||||
Hash string
|
||||
// Username is the display name for the user (optional, defaults to email)
|
||||
Username string
|
||||
}
|
||||
|
||||
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
|
||||
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
if c.Issuer == "" {
|
||||
return nil, fmt.Errorf("issuer is required")
|
||||
}
|
||||
if c.Storage.Type == "" {
|
||||
c.Storage.Type = "sqlite3"
|
||||
}
|
||||
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
|
||||
return nil, fmt.Errorf("storage file is required for sqlite3")
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
|
||||
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
|
||||
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
Type: c.Storage.Type,
|
||||
Config: map[string]interface{}{
|
||||
"file": c.Storage.Config.File,
|
||||
},
|
||||
},
|
||||
Web: dex.Web{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedHeaders: []string{"Authorization", "Content-Type"},
|
||||
},
|
||||
OAuth2: dex.OAuth2{
|
||||
SkipApprovalScreen: true,
|
||||
},
|
||||
Frontend: dex.Frontend{
|
||||
Issuer: "NetBird",
|
||||
Theme: "light",
|
||||
},
|
||||
EnablePasswordDB: true,
|
||||
StaticClients: []storage.Client{
|
||||
{
|
||||
ID: staticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: c.DashboardRedirectURIs,
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
Name: "NetBird CLI",
|
||||
Public: true,
|
||||
RedirectURIs: cliRedirectURIs,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add owner user if provided
|
||||
if c.Owner != nil && c.Owner.Email != "" && c.Owner.Hash != "" {
|
||||
username := c.Owner.Username
|
||||
if username == "" {
|
||||
username = c.Owner.Email
|
||||
}
|
||||
cfg.StaticPasswords = []dex.Password{
|
||||
{
|
||||
Email: c.Owner.Email,
|
||||
Hash: []byte(c.Owner.Hash),
|
||||
Username: username,
|
||||
UserID: uuid.New().String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Compile-time check that EmbeddedIdPManager implements Manager interface
|
||||
var _ Manager = (*EmbeddedIdPManager)(nil)
|
||||
|
||||
// Compile-time check that EmbeddedIdPManager implements OAuthConfigProvider interface
|
||||
var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil)
|
||||
|
||||
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
|
||||
type OAuthConfigProvider interface {
|
||||
GetIssuer() string
|
||||
GetKeysLocation() string
|
||||
GetClientIDs() []string
|
||||
GetUserIDClaim() string
|
||||
GetTokenEndpoint() string
|
||||
GetDeviceAuthEndpoint() string
|
||||
GetAuthorizationEndpoint() string
|
||||
GetDefaultScopes() string
|
||||
GetCLIClientID() string
|
||||
GetCLIRedirectURLs() []string
|
||||
}
|
||||
|
||||
// EmbeddedIdPManager implements the Manager interface using the embedded Dex IdP.
|
||||
type EmbeddedIdPManager struct {
|
||||
provider *dex.Provider
|
||||
appMetrics telemetry.AppMetrics
|
||||
config EmbeddedIdPConfig
|
||||
}
|
||||
|
||||
// NewEmbeddedIdPManager creates a new instance of EmbeddedIdPManager from a configuration.
|
||||
// It instantiates the underlying Dex provider internally.
|
||||
// Note: Storage defaults are applied in config loading (applyEmbeddedIdPConfig) based on Datadir.
|
||||
func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMetrics telemetry.AppMetrics) (*EmbeddedIdPManager, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("embedded IdP config is required")
|
||||
}
|
||||
|
||||
// Apply defaults for CLI redirect URIs
|
||||
if len(config.CLIRedirectURIs) == 0 {
|
||||
config.CLIRedirectURIs = []string{defaultCLIRedirectURL1, defaultCLIRedirectURL2}
|
||||
}
|
||||
|
||||
// there are some properties create when creating YAML config (e.g., auth clients)
|
||||
yamlConfig, err := config.ToYAMLConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
|
||||
|
||||
return &EmbeddedIdPManager{
|
||||
provider: provider,
|
||||
appMetrics: appMetrics,
|
||||
config: *config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Handler returns the HTTP handler for serving OIDC requests.
|
||||
func (m *EmbeddedIdPManager) Handler() http.Handler {
|
||||
return m.provider.Handler()
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the embedded IdP provider.
|
||||
func (m *EmbeddedIdPManager) Stop(ctx context.Context) error {
|
||||
return m.provider.Stop(ctx)
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (m *EmbeddedIdPManager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
||||
// TODO: implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from the embedded IdP via user ID.
|
||||
func (m *EmbeddedIdPManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, err := m.provider.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by ID: %w", err)
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: user.Email,
|
||||
Name: user.Username,
|
||||
ID: user.UserID,
|
||||
AppMetadata: appMetadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account.
|
||||
// Note: Embedded dex doesn't store account metadata, so this returns all users.
|
||||
func (m *EmbeddedIdPManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := m.provider.ListUsers(ctx)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
result := make([]*UserData, 0, len(users))
|
||||
for _, user := range users {
|
||||
result = append(result, &UserData{
|
||||
Email: user.Email,
|
||||
Name: user.Username,
|
||||
ID: user.UserID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// Note: Embedded dex doesn't store account metadata, so all users are indexed under UnsetAccountID.
|
||||
func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
users, err := m.provider.ListUsers(ctx)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range users {
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
|
||||
Email: user.Email,
|
||||
Name: user.Username,
|
||||
ID: user.UserID,
|
||||
})
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in the embedded IdP.
|
||||
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
_, err := m.provider.GetUser(ctx, email)
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("user with email %s already exists", email)
|
||||
}
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random password for the new user
|
||||
password := GeneratePassword(16, 2, 2, 2)
|
||||
|
||||
// Create the user via provider (handles hashing and ID generation)
|
||||
// The provider returns an encoded user ID in Dex's format (base64 protobuf with connector ID)
|
||||
userID, err := m.provider.CreateUser(ctx, email, name, password)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("created user %s in embedded IdP", email)
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: userID,
|
||||
Password: password,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
user, err := m.provider.GetUser(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return nil, nil // Return empty slice for not found
|
||||
}
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user by email: %w", err)
|
||||
}
|
||||
|
||||
return []*UserData{
|
||||
{
|
||||
Email: user.Email,
|
||||
Name: user.Username,
|
||||
ID: user.UserID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateUserWithPassword creates a new user in the embedded IdP with a provided password.
|
||||
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
|
||||
// This is useful for instance setup where the user provides their own password.
|
||||
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
// Check if user already exists
|
||||
_, err := m.provider.GetUser(ctx, email)
|
||||
if err == nil {
|
||||
return nil, fmt.Errorf("user with email %s already exists", email)
|
||||
}
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
|
||||
// Create the user via provider with the provided password
|
||||
userID, err := m.provider.CreateUser(ctx, email, name, password)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("created user %s in embedded IdP with provided password", email)
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: userID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resends an invitation to a user.
|
||||
func (m *EmbeddedIdPManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
// TODO: implement
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the embedded IdP by user ID.
|
||||
func (m *EmbeddedIdPManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
// Get user by ID to retrieve email (provider.DeleteUser requires email)
|
||||
user, err := m.provider.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return fmt.Errorf("failed to get user for deletion: %w", err)
|
||||
}
|
||||
|
||||
err = m.provider.DeleteUser(ctx, user.Email)
|
||||
if err != nil {
|
||||
if m.appMetrics != nil {
|
||||
m.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return fmt.Errorf("failed to delete user from embedded IdP: %w", err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("deleted user %s from embedded IdP", user.Email)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateConnector creates a new identity provider connector in Dex.
|
||||
// Returns the created connector config with the redirect URL populated.
|
||||
func (m *EmbeddedIdPManager) CreateConnector(ctx context.Context, cfg *dex.ConnectorConfig) (*dex.ConnectorConfig, error) {
|
||||
return m.provider.CreateConnector(ctx, cfg)
|
||||
}
|
||||
|
||||
// GetConnector retrieves an identity provider connector by ID.
|
||||
func (m *EmbeddedIdPManager) GetConnector(ctx context.Context, id string) (*dex.ConnectorConfig, error) {
|
||||
return m.provider.GetConnector(ctx, id)
|
||||
}
|
||||
|
||||
// ListConnectors returns all identity provider connectors.
|
||||
func (m *EmbeddedIdPManager) ListConnectors(ctx context.Context) ([]*dex.ConnectorConfig, error) {
|
||||
return m.provider.ListConnectors(ctx)
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing identity provider connector.
|
||||
func (m *EmbeddedIdPManager) UpdateConnector(ctx context.Context, cfg *dex.ConnectorConfig) error {
|
||||
// Preserve existing secret if not provided in update
|
||||
if cfg.ClientSecret == "" {
|
||||
existing, err := m.provider.GetConnector(ctx, cfg.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get existing connector: %w", err)
|
||||
}
|
||||
cfg.ClientSecret = existing.ClientSecret
|
||||
}
|
||||
return m.provider.UpdateConnector(ctx, cfg)
|
||||
}
|
||||
|
||||
// DeleteConnector removes an identity provider connector.
|
||||
func (m *EmbeddedIdPManager) DeleteConnector(ctx context.Context, id string) error {
|
||||
return m.provider.DeleteConnector(ctx, id)
|
||||
}
|
||||
|
||||
// GetIssuer returns the OIDC issuer URL.
|
||||
func (m *EmbeddedIdPManager) GetIssuer() string {
|
||||
return m.provider.GetIssuer()
|
||||
}
|
||||
|
||||
// GetTokenEndpoint returns the OAuth2 token endpoint URL.
|
||||
func (m *EmbeddedIdPManager) GetTokenEndpoint() string {
|
||||
return m.provider.GetTokenEndpoint()
|
||||
}
|
||||
|
||||
// GetDeviceAuthEndpoint returns the OAuth2 device authorization endpoint URL.
|
||||
func (m *EmbeddedIdPManager) GetDeviceAuthEndpoint() string {
|
||||
return m.provider.GetDeviceAuthEndpoint()
|
||||
}
|
||||
|
||||
// GetAuthorizationEndpoint returns the OAuth2 authorization endpoint URL.
|
||||
func (m *EmbeddedIdPManager) GetAuthorizationEndpoint() string {
|
||||
return m.provider.GetAuthorizationEndpoint()
|
||||
}
|
||||
|
||||
// GetDefaultScopes returns the default OAuth2 scopes for authentication.
|
||||
func (m *EmbeddedIdPManager) GetDefaultScopes() string {
|
||||
return defaultScopes
|
||||
}
|
||||
|
||||
// GetCLIClientID returns the client ID for CLI authentication.
|
||||
func (m *EmbeddedIdPManager) GetCLIClientID() string {
|
||||
return staticClientCLI
|
||||
}
|
||||
|
||||
// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client.
|
||||
func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
|
||||
if len(m.config.CLIRedirectURIs) == 0 {
|
||||
return []string{defaultCLIRedirectURL1, defaultCLIRedirectURL2}
|
||||
}
|
||||
return m.config.CLIRedirectURIs
|
||||
}
|
||||
|
||||
// GetKeysLocation returns the JWKS endpoint URL for token validation.
|
||||
func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
||||
return m.provider.GetKeysLocation()
|
||||
}
|
||||
|
||||
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
|
||||
func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||
return []string{staticClientDashboard, staticClientCLI}
|
||||
}
|
||||
|
||||
// GetUserIDClaim returns the JWT claim name used for user identification.
|
||||
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
|
||||
return defaultUserIDClaim
|
||||
}
|
||||
249
management/server/idp/embedded_test.go
Normal file
249
management/server/idp/embedded_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
)
|
||||
|
||||
func TestEmbeddedIdPManager_CreateUser_EndToEnd(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a temporary directory for the test
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create the embedded IDP config
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create the embedded IDP manager
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Test data
|
||||
email := "newuser@example.com"
|
||||
name := "New User"
|
||||
accountID := "test-account-id"
|
||||
invitedByEmail := "admin@example.com"
|
||||
|
||||
// Create the user
|
||||
userData, err := manager.CreateUser(ctx, email, name, accountID, invitedByEmail)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userData)
|
||||
|
||||
t.Logf("Created user: ID=%s, Email=%s, Name=%s, Password=%s",
|
||||
userData.ID, userData.Email, userData.Name, userData.Password)
|
||||
|
||||
// Verify user data
|
||||
assert.Equal(t, email, userData.Email)
|
||||
assert.Equal(t, name, userData.Name)
|
||||
assert.NotEmpty(t, userData.ID)
|
||||
assert.NotEmpty(t, userData.Password)
|
||||
assert.Equal(t, accountID, userData.AppMetadata.WTAccountID)
|
||||
assert.Equal(t, invitedByEmail, userData.AppMetadata.WTInvitedBy)
|
||||
|
||||
// Verify the user ID is in Dex's encoded format (base64 protobuf)
|
||||
rawUserID, connectorID, err := dex.DecodeDexUserID(userData.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, rawUserID)
|
||||
assert.Equal(t, "local", connectorID)
|
||||
|
||||
t.Logf("Decoded user ID: rawUserID=%s, connectorID=%s", rawUserID, connectorID)
|
||||
|
||||
// Verify we can look up the user by the encoded ID
|
||||
lookedUpUser, err := manager.GetUserDataByID(ctx, userData.ID, AppMetadata{WTAccountID: accountID})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, email, lookedUpUser.Email)
|
||||
|
||||
// Verify we can look up by email
|
||||
users, err := manager.GetUserByEmail(ctx, email)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 1)
|
||||
assert.Equal(t, email, users[0].Email)
|
||||
|
||||
// Verify creating duplicate user fails
|
||||
_, err = manager.CreateUser(ctx, email, name, accountID, invitedByEmail)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already exists")
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_GetUserDataByID_WithEncodedID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Create a user first
|
||||
userData, err := manager.CreateUser(ctx, "test@example.com", "Test User", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// The returned ID should be encoded
|
||||
encodedID := userData.ID
|
||||
|
||||
// Lookup should work with the encoded ID
|
||||
lookedUp, err := manager.GetUserDataByID(ctx, encodedID, AppMetadata{WTAccountID: "account1"})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test@example.com", lookedUp.Email)
|
||||
assert.Equal(t, "Test User", lookedUp.Name)
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_DeleteUser(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Create a user
|
||||
userData, err := manager.CreateUser(ctx, "delete-me@example.com", "Delete Me", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete the user using the encoded ID
|
||||
err = manager.DeleteUser(ctx, userData.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify user no longer exists
|
||||
_, err = manager.GetUserDataByID(ctx, userData.ID, AppMetadata{})
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_GetAccount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Create multiple users
|
||||
_, err = manager.CreateUser(ctx, "user1@example.com", "User 1", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.CreateUser(ctx, "user2@example.com", "User 2", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get all users for the account
|
||||
users, err := manager.GetAccount(ctx, "account1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 2)
|
||||
|
||||
emails := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
emails[i] = u.Email
|
||||
}
|
||||
assert.Contains(t, emails, "user1@example.com")
|
||||
assert.Contains(t, emails, "user2@example.com")
|
||||
}
|
||||
|
||||
func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
|
||||
// This test verifies that the user ID returned by CreateUser
|
||||
// matches the format that Dex uses in JWT tokens (the 'sub' claim)
|
||||
ctx := context.Background()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
config := &EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: EmbeddedStorageTypeConfig{
|
||||
File: filepath.Join(tmpDir, "dex.db"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = manager.Stop(ctx) }()
|
||||
|
||||
// Create a user
|
||||
userData, err := manager.CreateUser(ctx, "jwt-test@example.com", "JWT Test", "account1", "admin@example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// The ID should be in the format: base64(protobuf{user_id, connector_id})
|
||||
// Example: CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs
|
||||
|
||||
// Verify it can be decoded
|
||||
rawUserID, connectorID, err := dex.DecodeDexUserID(userData.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Raw user ID should be a UUID
|
||||
assert.Regexp(t, `^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, rawUserID)
|
||||
|
||||
// Connector ID should be "local" for password-based auth
|
||||
assert.Equal(t, "local", connectorID)
|
||||
|
||||
// Re-encoding should produce the same result
|
||||
reEncoded := dex.EncodeDexUserID(rawUserID, connectorID)
|
||||
assert.Equal(t, userData.ID, reEncoded)
|
||||
|
||||
t.Logf("User ID format verified:")
|
||||
t.Logf(" Encoded ID: %s", userData.ID)
|
||||
t.Logf(" Raw UUID: %s", rawUserID)
|
||||
t.Logf(" Connector: %s", connectorID)
|
||||
}
|
||||
@@ -72,6 +72,7 @@ type UserData struct {
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
AppMetadata AppMetadata `json:"app_metadata"`
|
||||
Password string `json:"-"` // Plain password, only set on user creation, excluded from JSON
|
||||
}
|
||||
|
||||
func (u *UserData) MarshalBinary() (data []byte, err error) {
|
||||
|
||||
136
management/server/instance/manager.go
Normal file
136
management/server/instance/manager.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// Manager handles instance-level operations like initial setup.
|
||||
type Manager interface {
|
||||
// IsSetupRequired checks if instance setup is required.
|
||||
// Returns true if embedded IDP is enabled and no accounts exist.
|
||||
IsSetupRequired(ctx context.Context) (bool, error)
|
||||
|
||||
// CreateOwnerUser creates the initial owner user in the embedded IDP.
|
||||
// This should only be called when IsSetupRequired returns true.
|
||||
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
}
|
||||
|
||||
// DefaultManager is the default implementation of Manager.
|
||||
type DefaultManager struct {
|
||||
store store.Store
|
||||
embeddedIdpManager *idp.EmbeddedIdPManager
|
||||
|
||||
setupRequired bool
|
||||
setupMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewManager creates a new instance manager.
|
||||
// If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults.
|
||||
func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) {
|
||||
embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager)
|
||||
|
||||
m := &DefaultManager{
|
||||
store: store,
|
||||
embeddedIdpManager: embeddedIdp,
|
||||
setupRequired: false,
|
||||
}
|
||||
|
||||
if embeddedIdp != nil {
|
||||
err := m.loadSetupRequired(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
|
||||
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.setupMu.Lock()
|
||||
m.setupRequired = len(users) == 0
|
||||
m.setupMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsSetupRequired checks if instance setup is required.
|
||||
// Setup is required when:
|
||||
// 1. Embedded IDP is enabled
|
||||
// 2. No accounts exist in the store
|
||||
func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) {
|
||||
if m.embeddedIdpManager == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
m.setupMu.RLock()
|
||||
defer m.setupMu.RUnlock()
|
||||
|
||||
return m.setupRequired, nil
|
||||
}
|
||||
|
||||
// CreateOwnerUser creates the initial owner user in the embedded IDP.
|
||||
func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
|
||||
if err := m.validateSetupInfo(email, password, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.embeddedIdpManager == nil {
|
||||
return nil, errors.New("embedded IDP is not enabled")
|
||||
}
|
||||
|
||||
m.setupMu.RLock()
|
||||
setupRequired := m.setupRequired
|
||||
m.setupMu.RUnlock()
|
||||
|
||||
if !setupRequired {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
|
||||
}
|
||||
|
||||
userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
|
||||
}
|
||||
|
||||
m.setupMu.Lock()
|
||||
m.setupRequired = false
|
||||
m.setupMu.Unlock()
|
||||
|
||||
log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email)
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
|
||||
if email == "" {
|
||||
return status.Errorf(status.InvalidArgument, "email is required")
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "invalid email format")
|
||||
}
|
||||
if name == "" {
|
||||
return status.Errorf(status.InvalidArgument, "name is required")
|
||||
}
|
||||
if password == "" {
|
||||
return status.Errorf(status.InvalidArgument, "password is required")
|
||||
}
|
||||
if len(password) < 8 {
|
||||
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
268
management/server/instance/manager_test.go
Normal file
268
management/server/instance/manager_test.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package instance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
)
|
||||
|
||||
// mockStore implements a minimal store.Store for testing
|
||||
type mockStore struct {
|
||||
accountsCount int64
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) {
|
||||
if m.err != nil {
|
||||
return 0, m.err
|
||||
}
|
||||
return m.accountsCount, nil
|
||||
}
|
||||
|
||||
// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing
|
||||
type mockEmbeddedIdPManager struct {
|
||||
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
|
||||
}
|
||||
|
||||
func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
if m.createUserFunc != nil {
|
||||
return m.createUserFunc(ctx, email, password, name)
|
||||
}
|
||||
return &idp.UserData{
|
||||
ID: "test-user-id",
|
||||
Email: email,
|
||||
Name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// testManager is a test implementation that accepts our mock types
|
||||
type testManager struct {
|
||||
store *mockStore
|
||||
embeddedIdpManager *mockEmbeddedIdPManager
|
||||
}
|
||||
|
||||
func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) {
|
||||
if m.embeddedIdpManager == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
count, err := m.store.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
if m.embeddedIdpManager == nil {
|
||||
return nil, errors.New("embedded IDP is not enabled")
|
||||
}
|
||||
|
||||
return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
|
||||
}
|
||||
|
||||
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 0},
|
||||
embeddedIdpManager: nil, // No embedded IDP
|
||||
}
|
||||
|
||||
required, err := manager.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.False(t, required, "setup should not be required when embedded IDP is disabled")
|
||||
}
|
||||
|
||||
func TestIsSetupRequired_NoAccounts(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 0},
|
||||
embeddedIdpManager: &mockEmbeddedIdPManager{},
|
||||
}
|
||||
|
||||
required, err := manager.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.True(t, required, "setup should be required when no accounts exist")
|
||||
}
|
||||
|
||||
func TestIsSetupRequired_AccountsExist(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 1},
|
||||
embeddedIdpManager: &mockEmbeddedIdPManager{},
|
||||
}
|
||||
|
||||
required, err := manager.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.False(t, required, "setup should not be required when accounts exist")
|
||||
}
|
||||
|
||||
func TestIsSetupRequired_MultipleAccounts(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 5},
|
||||
embeddedIdpManager: &mockEmbeddedIdPManager{},
|
||||
}
|
||||
|
||||
required, err := manager.IsSetupRequired(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.False(t, required, "setup should not be required when multiple accounts exist")
|
||||
}
|
||||
|
||||
func TestIsSetupRequired_StoreError(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{err: errors.New("database error")},
|
||||
embeddedIdpManager: &mockEmbeddedIdPManager{},
|
||||
}
|
||||
|
||||
_, err := manager.IsSetupRequired(context.Background())
|
||||
assert.Error(t, err, "should return error when store fails")
|
||||
}
|
||||
|
||||
func TestCreateOwnerUser_Success(t *testing.T) {
|
||||
expectedEmail := "admin@example.com"
|
||||
expectedName := "Admin User"
|
||||
expectedPassword := "securepassword123"
|
||||
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 0},
|
||||
embeddedIdpManager: &mockEmbeddedIdPManager{
|
||||
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
assert.Equal(t, expectedEmail, email)
|
||||
assert.Equal(t, expectedPassword, password)
|
||||
assert.Equal(t, expectedName, name)
|
||||
return &idp.UserData{
|
||||
ID: "created-user-id",
|
||||
Email: email,
|
||||
Name: name,
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "created-user-id", userData.ID)
|
||||
assert.Equal(t, expectedEmail, userData.Email)
|
||||
assert.Equal(t, expectedName, userData.Name)
|
||||
}
|
||||
|
||||
func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 0},
|
||||
embeddedIdpManager: nil,
|
||||
}
|
||||
|
||||
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
|
||||
assert.Error(t, err, "should return error when embedded IDP is disabled")
|
||||
assert.Contains(t, err.Error(), "embedded IDP is not enabled")
|
||||
}
|
||||
|
||||
func TestCreateOwnerUser_IdPError(t *testing.T) {
|
||||
manager := &testManager{
|
||||
store: &mockStore{accountsCount: 0},
|
||||
embeddedIdpManager: &mockEmbeddedIdPManager{
|
||||
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
|
||||
return nil, errors.New("user already exists")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
|
||||
assert.Error(t, err, "should return error when IDP fails")
|
||||
}
|
||||
|
||||
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
|
||||
manager := &DefaultManager{
|
||||
setupRequired: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
userName string
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
email: "admin@example.com",
|
||||
password: "password123",
|
||||
userName: "Admin User",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty email",
|
||||
email: "",
|
||||
password: "password123",
|
||||
userName: "Admin User",
|
||||
expectError: true,
|
||||
errorMsg: "email is required",
|
||||
},
|
||||
{
|
||||
name: "invalid email format",
|
||||
email: "not-an-email",
|
||||
password: "password123",
|
||||
userName: "Admin User",
|
||||
expectError: true,
|
||||
errorMsg: "invalid email format",
|
||||
},
|
||||
{
|
||||
name: "empty name",
|
||||
email: "admin@example.com",
|
||||
password: "password123",
|
||||
userName: "",
|
||||
expectError: true,
|
||||
errorMsg: "name is required",
|
||||
},
|
||||
{
|
||||
name: "empty password",
|
||||
email: "admin@example.com",
|
||||
password: "",
|
||||
userName: "Admin User",
|
||||
expectError: true,
|
||||
errorMsg: "password is required",
|
||||
},
|
||||
{
|
||||
name: "password too short",
|
||||
email: "admin@example.com",
|
||||
password: "short",
|
||||
userName: "Admin User",
|
||||
expectError: true,
|
||||
errorMsg: "password must be at least 8 characters",
|
||||
},
|
||||
{
|
||||
name: "password exactly 8 characters",
|
||||
email: "admin@example.com",
|
||||
password: "12345678",
|
||||
userName: "Admin User",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := manager.validateSetupInfo(tt.email, tt.password, tt.userName)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
|
||||
manager := &DefaultManager{
|
||||
setupRequired: false,
|
||||
embeddedIdpManager: &idp.EmbeddedIdPManager{},
|
||||
}
|
||||
|
||||
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "setup already completed")
|
||||
}
|
||||
@@ -381,7 +381,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
@@ -242,6 +242,7 @@ func startServer(
|
||||
nil,
|
||||
server.MockIntegratedValidator{},
|
||||
networkMapController,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
||||
@@ -27,13 +27,13 @@ import (
|
||||
var _ account.Manager = (*MockAccountManager)(nil)
|
||||
|
||||
type MockAccountManager struct {
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
|
||||
GetOrCreateAccountByUserFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error)
|
||||
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
|
||||
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType,
|
||||
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
|
||||
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
|
||||
GetAccountIDByUserIdFunc func(ctx context.Context, userAuth auth.UserAuth) (string, error)
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
@@ -129,6 +129,12 @@ type MockAccountManager struct {
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
|
||||
GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
|
||||
GetIdentityProvidersFunc func(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error)
|
||||
CreateIdentityProviderFunc func(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
UpdateIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
|
||||
DeleteIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
@@ -237,10 +243,10 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
|
||||
|
||||
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetOrCreateAccountByUser(
|
||||
ctx context.Context, userId, domain string,
|
||||
ctx context.Context, userAuth auth.UserAuth,
|
||||
) (*types.Account, error) {
|
||||
if am.GetOrCreateAccountByUserFunc != nil {
|
||||
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
|
||||
return am.GetOrCreateAccountByUserFunc(ctx, userAuth)
|
||||
}
|
||||
return nil, status.Errorf(
|
||||
codes.Unimplemented,
|
||||
@@ -276,9 +282,9 @@ func (am *MockAccountManager) AccountExists(ctx context.Context, accountID strin
|
||||
}
|
||||
|
||||
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
|
||||
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
if am.GetAccountIDByUserIdFunc != nil {
|
||||
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
|
||||
return am.GetAccountIDByUserIdFunc(ctx, userAuth)
|
||||
}
|
||||
return "", status.Errorf(
|
||||
codes.Unimplemented,
|
||||
@@ -993,3 +999,43 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac
|
||||
func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
|
||||
return "something", nil
|
||||
}
|
||||
|
||||
// GetIdentityProvider mocks GetIdentityProvider of the AccountManager interface
|
||||
func (am *MockAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
|
||||
if am.GetIdentityProviderFunc != nil {
|
||||
return am.GetIdentityProviderFunc(ctx, accountID, idpID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetIdentityProvider is not implemented")
|
||||
}
|
||||
|
||||
// GetIdentityProviders mocks GetIdentityProviders of the AccountManager interface
|
||||
func (am *MockAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||
if am.GetIdentityProvidersFunc != nil {
|
||||
return am.GetIdentityProvidersFunc(ctx, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetIdentityProviders is not implemented")
|
||||
}
|
||||
|
||||
// CreateIdentityProvider mocks CreateIdentityProvider of the AccountManager interface
|
||||
func (am *MockAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
if am.CreateIdentityProviderFunc != nil {
|
||||
return am.CreateIdentityProviderFunc(ctx, accountID, userID, idp)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateIdentityProvider is not implemented")
|
||||
}
|
||||
|
||||
// UpdateIdentityProvider mocks UpdateIdentityProvider of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
|
||||
if am.UpdateIdentityProviderFunc != nil {
|
||||
return am.UpdateIdentityProviderFunc(ctx, accountID, idpID, userID, idp)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateIdentityProvider is not implemented")
|
||||
}
|
||||
|
||||
// DeleteIdentityProvider mocks DeleteIdentityProvider of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error {
|
||||
if am.DeleteIdentityProviderFunc != nil {
|
||||
return am.DeleteIdentityProviderFunc(ctx, accountID, idpID, userID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteIdentityProvider is not implemented")
|
||||
}
|
||||
|
||||
@@ -865,7 +865,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
userID := testUserID
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false)
|
||||
|
||||
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
|
||||
|
||||
|
||||
@@ -502,7 +502,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -689,7 +689,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
someUser := "some_user"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
|
||||
account.Users[someUser] = &types.User{
|
||||
Id: someUser,
|
||||
Role: testCase.role,
|
||||
@@ -759,7 +759,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
|
||||
account.Users[regularUser] = &types.User{
|
||||
Id: regularUser,
|
||||
Role: types.UserRoleUser,
|
||||
@@ -2124,7 +2124,7 @@ func Test_DeletePeer(t *testing.T) {
|
||||
// account with an admin and a regular user
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
|
||||
account.Peers = map[string]*nbpeer.Peer{
|
||||
"peer1": {
|
||||
ID: "peer1",
|
||||
@@ -2307,12 +2307,12 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
@@ -2344,12 +2344,12 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create regular user (not pending approval)
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser := types.NewRegularUser("regular-user", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
@@ -2378,12 +2378,12 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
@@ -2443,12 +2443,12 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create account
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
|
||||
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create regular user (not pending approval)
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser := types.NewRegularUser("regular-user", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -3,33 +3,35 @@ package modules
|
||||
type Module string
|
||||
|
||||
const (
|
||||
Networks Module = "networks"
|
||||
Peers Module = "peers"
|
||||
Groups Module = "groups"
|
||||
Settings Module = "settings"
|
||||
Accounts Module = "accounts"
|
||||
Dns Module = "dns"
|
||||
Nameservers Module = "nameservers"
|
||||
Events Module = "events"
|
||||
Policies Module = "policies"
|
||||
Routes Module = "routes"
|
||||
Users Module = "users"
|
||||
SetupKeys Module = "setup_keys"
|
||||
Pats Module = "pats"
|
||||
Networks Module = "networks"
|
||||
Peers Module = "peers"
|
||||
Groups Module = "groups"
|
||||
Settings Module = "settings"
|
||||
Accounts Module = "accounts"
|
||||
Dns Module = "dns"
|
||||
Nameservers Module = "nameservers"
|
||||
Events Module = "events"
|
||||
Policies Module = "policies"
|
||||
Routes Module = "routes"
|
||||
Users Module = "users"
|
||||
SetupKeys Module = "setup_keys"
|
||||
Pats Module = "pats"
|
||||
IdentityProviders Module = "identity_providers"
|
||||
)
|
||||
|
||||
var All = map[Module]struct{}{
|
||||
Networks: {},
|
||||
Peers: {},
|
||||
Groups: {},
|
||||
Settings: {},
|
||||
Accounts: {},
|
||||
Dns: {},
|
||||
Nameservers: {},
|
||||
Events: {},
|
||||
Policies: {},
|
||||
Routes: {},
|
||||
Users: {},
|
||||
SetupKeys: {},
|
||||
Pats: {},
|
||||
Networks: {},
|
||||
Peers: {},
|
||||
Groups: {},
|
||||
Settings: {},
|
||||
Accounts: {},
|
||||
Dns: {},
|
||||
Nameservers: {},
|
||||
Events: {},
|
||||
Policies: {},
|
||||
Routes: {},
|
||||
Users: {},
|
||||
SetupKeys: {},
|
||||
Pats: {},
|
||||
IdentityProviders: {},
|
||||
}
|
||||
|
||||
@@ -93,5 +93,11 @@ var NetworkAdmin = RolePermissions{
|
||||
operations.Update: false,
|
||||
operations.Delete: false,
|
||||
},
|
||||
modules.IdentityProviders: {
|
||||
operations.Read: true,
|
||||
operations.Create: false,
|
||||
operations.Update: false,
|
||||
operations.Delete: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
|
||||
ID: "peer1",
|
||||
}
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
|
||||
account.Users[admin.Id] = admin
|
||||
account.Users[user.Id] = user
|
||||
account.Peers["peer1"] = peer1
|
||||
|
||||
@@ -1320,7 +1320,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
|
||||
account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false)
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
@@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -99,7 +100,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -204,7 +205,7 @@ func TestGetSetupKeys(t *testing.T) {
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -471,7 +472,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t
|
||||
}
|
||||
|
||||
userID := "testingUser"
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbutil "github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
// storeFileName Store file name. Stored in the datadir
|
||||
@@ -263,3 +264,8 @@ func (s *FileStore) Close(ctx context.Context) error {
|
||||
func (s *FileStore) GetStoreEngine() types.Engine {
|
||||
return types.FileStoreEngine
|
||||
}
|
||||
|
||||
// SetFieldEncrypt is a no-op for FileStore as it doesn't support field encryption.
|
||||
func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
|
||||
// no-op: FileStore stores data in plaintext JSON; encryption is not supported
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -57,13 +58,13 @@ const (
|
||||
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
storeEngine types.Engine
|
||||
pool *pgxpool.Pool
|
||||
|
||||
db *gorm.DB
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
storeEngine types.Engine
|
||||
pool *pgxpool.Pool
|
||||
fieldEncrypt *crypt.FieldEncrypt
|
||||
transactionTimeout time.Duration
|
||||
}
|
||||
|
||||
@@ -175,6 +176,13 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
generateAccountSQLTypes(account)
|
||||
|
||||
// Encrypt sensitive user data before saving
|
||||
for i := range account.UsersG {
|
||||
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
}
|
||||
@@ -440,7 +448,18 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users)
|
||||
usersCopy := make([]*types.User, len(users))
|
||||
for i, user := range users {
|
||||
userCopy := user.Copy()
|
||||
userCopy.Email = user.Email
|
||||
userCopy.Name = user.Name
|
||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
}
|
||||
usersCopy[i] = userCopy
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&usersCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save users to store")
|
||||
@@ -450,7 +469,15 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
||||
|
||||
// SaveUser saves the given user to the database.
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
result := s.db.Save(user)
|
||||
userCopy := user.Copy()
|
||||
userCopy.Email = user.Email
|
||||
userCopy.Name = user.Name
|
||||
|
||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
}
|
||||
|
||||
result := s.db.Save(userCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save user to store")
|
||||
@@ -600,6 +627,10 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -618,6 +649,10 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -654,6 +689,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
return nil, status.Errorf(status.Internal, "issue getting users from store")
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
@@ -672,6 +713,10 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
return nil, status.Errorf(status.Internal, "failed to get account owner from the store")
|
||||
}
|
||||
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -866,6 +911,9 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
if user.AutoGroups == nil {
|
||||
user.AutoGroups = []string{}
|
||||
}
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
account.Users[user.Id] = &user
|
||||
user.PATsG = nil
|
||||
}
|
||||
@@ -1141,6 +1189,9 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for i := range account.UsersG {
|
||||
user := &account.UsersG[i]
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken)
|
||||
if userPats, ok := patsByUserID[user.Id]; ok {
|
||||
for j := range userPats {
|
||||
@@ -1545,7 +1596,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
}
|
||||
|
||||
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1555,7 +1606,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
|
||||
var autoGroups []byte
|
||||
var lastLogin, createdAt sql.NullTime
|
||||
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = &lastLogin.Time
|
||||
@@ -3012,6 +3063,11 @@ func (s *SqlStore) GetDB() *gorm.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
||||
func (s *SqlStore) SetFieldEncrypt(enc *crypt.FieldEncrypt) {
|
||||
s.fieldEncrypt = enc
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
|
||||
@@ -2090,7 +2091,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
setupKeys := map[string]*types.SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
|
||||
owner := types.NewOwnerUser(userID)
|
||||
owner := types.NewOwnerUser(userID, "", "")
|
||||
owner.AccountID = accountID
|
||||
users[userID] = owner
|
||||
|
||||
@@ -3114,6 +3115,138 @@ func TestSqlStore_SaveUsers(t *testing.T) {
|
||||
require.Equal(t, users[1].AutoGroups, user.AutoGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveUserWithEncryption(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Enable encryption
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
store.SetFieldEncrypt(fieldEncrypt)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
// rawUser is used to read raw (potentially encrypted) data from the database
|
||||
// without any gorm hooks or automatic decryption
|
||||
type rawUser struct {
|
||||
Id string
|
||||
Email string
|
||||
Name string
|
||||
}
|
||||
|
||||
t.Run("save user with empty email and name", func(t *testing.T) {
|
||||
user := &types.User{
|
||||
Id: "user-empty-fields",
|
||||
AccountID: accountID,
|
||||
Role: types.UserRoleUser,
|
||||
Email: "",
|
||||
Name: "",
|
||||
AutoGroups: []string{"groupA"},
|
||||
}
|
||||
err = store.SaveUser(context.Background(), user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify using direct database query that empty strings remain empty (not encrypted)
|
||||
var raw rawUser
|
||||
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", user.Id).First(&raw).Error
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", raw.Email, "empty email should remain empty in database")
|
||||
require.Equal(t, "", raw.Name, "empty name should remain empty in database")
|
||||
|
||||
// Verify manual decryption returns empty strings
|
||||
decryptedEmail, err := fieldEncrypt.Decrypt(raw.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", decryptedEmail)
|
||||
|
||||
decryptedName, err := fieldEncrypt.Decrypt(raw.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", decryptedName)
|
||||
})
|
||||
|
||||
t.Run("save user with email and name", func(t *testing.T) {
|
||||
user := &types.User{
|
||||
Id: "user-with-fields",
|
||||
AccountID: accountID,
|
||||
Role: types.UserRoleAdmin,
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
AutoGroups: []string{"groupB"},
|
||||
}
|
||||
err = store.SaveUser(context.Background(), user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify using direct database query that the data is encrypted (not plaintext)
|
||||
var raw rawUser
|
||||
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", user.Id).First(&raw).Error
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, "test@example.com", raw.Email, "email should be encrypted in database")
|
||||
require.NotEqual(t, "Test User", raw.Name, "name should be encrypted in database")
|
||||
|
||||
// Verify manual decryption returns correct values
|
||||
decryptedEmail, err := fieldEncrypt.Decrypt(raw.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test@example.com", decryptedEmail)
|
||||
|
||||
decryptedName, err := fieldEncrypt.Decrypt(raw.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Test User", decryptedName)
|
||||
})
|
||||
|
||||
t.Run("save multiple users with mixed fields", func(t *testing.T) {
|
||||
users := []*types.User{
|
||||
{
|
||||
Id: "batch-user-1",
|
||||
AccountID: accountID,
|
||||
Email: "",
|
||||
Name: "",
|
||||
},
|
||||
{
|
||||
Id: "batch-user-2",
|
||||
AccountID: accountID,
|
||||
Email: "batch@example.com",
|
||||
Name: "Batch User",
|
||||
},
|
||||
}
|
||||
err = store.SaveUsers(context.Background(), users)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify first user (empty fields) using direct database query
|
||||
var raw1 rawUser
|
||||
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", "batch-user-1").First(&raw1).Error
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", raw1.Email, "empty email should remain empty in database")
|
||||
require.Equal(t, "", raw1.Name, "empty name should remain empty in database")
|
||||
|
||||
// Verify second user (with fields) using direct database query
|
||||
var raw2 rawUser
|
||||
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", "batch-user-2").First(&raw2).Error
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, "batch@example.com", raw2.Email, "email should be encrypted in database")
|
||||
require.NotEqual(t, "Batch User", raw2.Name, "name should be encrypted in database")
|
||||
|
||||
// Verify manual decryption returns empty strings for first user
|
||||
decryptedEmail1, err := fieldEncrypt.Decrypt(raw1.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", decryptedEmail1)
|
||||
|
||||
decryptedName1, err := fieldEncrypt.Decrypt(raw1.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "", decryptedName1)
|
||||
|
||||
// Verify manual decryption returns correct values for second user
|
||||
decryptedEmail2, err := fieldEncrypt.Decrypt(raw2.Email)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "batch@example.com", decryptedEmail2)
|
||||
|
||||
decryptedName2, err := fieldEncrypt.Decrypt(raw2.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Batch User", decryptedName2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteUser(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -204,6 +205,9 @@ type Store interface {
|
||||
MarkAccountPrimary(ctx context.Context, accountID string) error
|
||||
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
|
||||
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
|
||||
|
||||
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
|
||||
SetFieldEncrypt(enc *crypt.FieldEncrypt)
|
||||
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
|
||||
}
|
||||
|
||||
@@ -340,6 +344,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[types.User](ctx, db, "name", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[types.User](ctx, db, "email", "")
|
||||
},
|
||||
}
|
||||
} // migratePostAuto migrates the SQLite database to the latest schema
|
||||
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
||||
|
||||
122
management/server/types/identity_provider.go
Normal file
122
management/server/types/identity_provider.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Identity provider validation errors
|
||||
var (
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
)
|
||||
|
||||
// IdentityProviderType is the type of identity provider
|
||||
type IdentityProviderType string
|
||||
|
||||
const (
|
||||
// IdentityProviderTypeOIDC is a generic OIDC identity provider
|
||||
IdentityProviderTypeOIDC IdentityProviderType = "oidc"
|
||||
// IdentityProviderTypeZitadel is the Zitadel identity provider
|
||||
IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
|
||||
// IdentityProviderTypeEntra is the Microsoft Entra (Azure AD) identity provider
|
||||
IdentityProviderTypeEntra IdentityProviderType = "entra"
|
||||
// IdentityProviderTypeGoogle is the Google identity provider
|
||||
IdentityProviderTypeGoogle IdentityProviderType = "google"
|
||||
// IdentityProviderTypeOkta is the Okta identity provider
|
||||
IdentityProviderTypeOkta IdentityProviderType = "okta"
|
||||
// IdentityProviderTypePocketID is the PocketID identity provider
|
||||
IdentityProviderTypePocketID IdentityProviderType = "pocketid"
|
||||
// IdentityProviderTypeMicrosoft is the Microsoft identity provider
|
||||
IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft"
|
||||
// IdentityProviderTypeAuthentik is the Authentik identity provider
|
||||
IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
|
||||
// IdentityProviderTypeKeycloak is the Keycloak identity provider
|
||||
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an identity provider configuration
|
||||
type IdentityProvider struct {
|
||||
// ID is the unique identifier of the identity provider
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// Type is the type of identity provider
|
||||
Type IdentityProviderType
|
||||
// Name is a human-readable name for the identity provider
|
||||
Name string
|
||||
// Issuer is the OIDC issuer URL
|
||||
Issuer string
|
||||
// ClientID is the OAuth2 client ID
|
||||
ClientID string
|
||||
// ClientSecret is the OAuth2 client secret
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
// Copy returns a copy of the IdentityProvider
|
||||
func (idp *IdentityProvider) Copy() *IdentityProvider {
|
||||
return &IdentityProvider{
|
||||
ID: idp.ID,
|
||||
AccountID: idp.AccountID,
|
||||
Type: idp.Type,
|
||||
Name: idp.Name,
|
||||
Issuer: idp.Issuer,
|
||||
ClientID: idp.ClientID,
|
||||
ClientSecret: idp.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// EventMeta returns a map of metadata for activity events
|
||||
func (idp *IdentityProvider) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": idp.Name,
|
||||
"type": string(idp.Type),
|
||||
"issuer": idp.Issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the identity provider configuration
|
||||
func (idp *IdentityProvider) Validate() error {
|
||||
if idp.Name == "" {
|
||||
return ErrIdentityProviderNameRequired
|
||||
}
|
||||
if idp.Type == "" {
|
||||
return ErrIdentityProviderTypeRequired
|
||||
}
|
||||
if !idp.Type.IsValid() {
|
||||
return ErrIdentityProviderTypeUnsupported
|
||||
}
|
||||
if !idp.Type.HasBuiltInIssuer() && idp.Issuer == "" {
|
||||
return ErrIdentityProviderIssuerRequired
|
||||
}
|
||||
if idp.Issuer != "" {
|
||||
parsedURL, err := url.Parse(idp.Issuer)
|
||||
if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return ErrIdentityProviderIssuerInvalid
|
||||
}
|
||||
}
|
||||
if idp.ClientID == "" {
|
||||
return ErrIdentityProviderClientIDRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValid checks if the given type is a supported identity provider type
|
||||
func (t IdentityProviderType) IsValid() bool {
|
||||
switch t {
|
||||
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
|
||||
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
|
||||
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasBuiltInIssuer returns true for types that don't require an issuer URL
|
||||
func (t IdentityProviderType) HasBuiltInIssuer() bool {
|
||||
return t == IdentityProviderTypeGoogle || t == IdentityProviderTypeMicrosoft
|
||||
}
|
||||
137
management/server/types/identity_provider_test.go
Normal file
137
management/server/types/identity_provider_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIdentityProvider_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idp *IdentityProvider
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "valid OIDC provider",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid OIDC provider with path",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com/oauth2/issuer",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
idp: &IdentityProvider{
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderNameRequired,
|
||||
},
|
||||
{
|
||||
name: "missing type",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderTypeRequired,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: "invalid",
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderTypeUnsupported,
|
||||
},
|
||||
{
|
||||
name: "missing issuer for OIDC",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerRequired,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - no scheme",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - no host",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - just path",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "/oauth2/issuer",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderClientIDRequired,
|
||||
},
|
||||
{
|
||||
name: "Google provider without issuer is valid",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Google SSO",
|
||||
Type: IdentityProviderTypeGoogle,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Microsoft provider without issuer is valid",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Microsoft SSO",
|
||||
Type: IdentityProviderTypeMicrosoft,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.idp.Validate()
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -65,7 +66,11 @@ type UserInfo struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
PendingApproval bool `json:"pending_approval"`
|
||||
Password string `json:"password"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
// IdPID is the identity provider ID (connector ID) extracted from the Dex-encoded user ID.
|
||||
// This field is only populated when the user ID can be decoded from Dex's format.
|
||||
IdPID string `json:"idp_id,omitempty"`
|
||||
}
|
||||
|
||||
// User represents a user of the system
|
||||
@@ -96,6 +101,9 @@ type User struct {
|
||||
Issued string `gorm:"default:api"`
|
||||
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
|
||||
Name string `gorm:"default:''"`
|
||||
Email string `gorm:"default:''"`
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
@@ -143,10 +151,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
|
||||
name := u.Name
|
||||
if u.IsServiceUser {
|
||||
name = u.ServiceUserName
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
@@ -178,6 +192,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
PendingApproval: u.PendingApproval,
|
||||
Password: userData.Password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -204,11 +219,13 @@ func (u *User) Copy() *User {
|
||||
CreatedAt: u.CreatedAt,
|
||||
Issued: u.Issued,
|
||||
IntegrationReference: u.IntegrationReference,
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, email string, name string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
@@ -218,20 +235,70 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
|
||||
AutoGroups: autoGroups,
|
||||
Issued: issued,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Name: name,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewRegularUser(id, email, name string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, email, name)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, "", "")
|
||||
}
|
||||
|
||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||
func NewOwnerUser(id string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewOwnerUser(id string, email string, name string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, email, name)
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the user's sensitive fields (Email and Name) in place.
|
||||
func (u *User) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if u.Email != "" {
|
||||
u.Email, err = enc.Encrypt(u.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if u.Name != "" {
|
||||
u.Name, err = enc.Encrypt(u.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the user's sensitive fields (Email and Name) in place.
|
||||
func (u *User) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if u.Email != "" {
|
||||
u.Email, err = enc.Decrypt(u.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if u.Name != "" {
|
||||
u.Name, err = enc.Decrypt(u.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
298
management/server/types/user_test.go
Normal file
298
management/server/types/user_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUser_EncryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("encrypt email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
|
||||
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
|
||||
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
|
||||
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt empty email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-2",
|
||||
Email: "",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt only email", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-3",
|
||||
Email: "test@example.com",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
|
||||
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt only name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-4",
|
||||
Email: "",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
|
||||
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
|
||||
})
|
||||
|
||||
t.Run("nil encryptor returns no error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-5",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
|
||||
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_DecryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("decrypt email and name", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-1",
|
||||
Email: originalEmail,
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
|
||||
})
|
||||
|
||||
t.Run("decrypt empty email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-2",
|
||||
Email: "",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("decrypt only email", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
|
||||
user := &User{
|
||||
Id: "user-3",
|
||||
Email: originalEmail,
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("decrypt only name", func(t *testing.T) {
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-4",
|
||||
Email: "",
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
|
||||
})
|
||||
|
||||
t.Run("nil encryptor returns no error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-5",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
|
||||
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
|
||||
})
|
||||
|
||||
t.Run("decrypt with invalid ciphertext returns error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-6",
|
||||
Email: "not-valid-base64-ciphertext!!!",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decrypt email")
|
||||
})
|
||||
|
||||
t.Run("decrypt with wrong key returns error", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-7",
|
||||
Email: originalEmail,
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
differentKey, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
differentEncrypt, err := crypt.NewFieldEncrypt(differentKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(differentEncrypt)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decrypt email")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_EncryptDecryptRoundTrip(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
email string
|
||||
uname string
|
||||
}{
|
||||
{
|
||||
name: "standard email and name",
|
||||
email: "user@example.com",
|
||||
uname: "John Doe",
|
||||
},
|
||||
{
|
||||
name: "email with special characters",
|
||||
email: "user+tag@sub.example.com",
|
||||
uname: "O'Brien, Mary-Jane",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
email: "user@example.com",
|
||||
uname: "Jean-Pierre Müller 日本語",
|
||||
},
|
||||
{
|
||||
name: "long values",
|
||||
email: "very.long.email.address.that.is.quite.extended@subdomain.example.organization.com",
|
||||
uname: "A Very Long Name That Contains Many Words And Is Quite Extended For Testing Purposes",
|
||||
},
|
||||
{
|
||||
name: "empty email only",
|
||||
email: "",
|
||||
uname: "Name Only",
|
||||
},
|
||||
{
|
||||
name: "empty name only",
|
||||
email: "email@only.com",
|
||||
uname: "",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
email: "",
|
||||
uname: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "test-user",
|
||||
Email: tc.email,
|
||||
Name: tc.uname,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.email != "" {
|
||||
assert.NotEqual(t, tc.email, user.Email, "email should be encrypted")
|
||||
}
|
||||
if tc.uname != "" {
|
||||
assert.NotEqual(t, tc.uname, user.Name, "name should be encrypted")
|
||||
}
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.email, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, tc.uname, user.Name, "decrypted name should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -40,7 +41,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
newUserID := uuid.New().String()
|
||||
newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI)
|
||||
newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI, "", "")
|
||||
newUser.AccountID = accountID
|
||||
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
||||
|
||||
@@ -104,7 +105,12 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
inviterID = createdBy
|
||||
}
|
||||
|
||||
idpUser, err := am.createNewIdpUser(ctx, accountID, inviterID, invite)
|
||||
var idpUser *idp.UserData
|
||||
if IsEmbeddedIdp(am.idpManager) {
|
||||
idpUser, err = am.createEmbeddedIdpUser(ctx, accountID, inviterID, invite)
|
||||
} else {
|
||||
idpUser, err = am.createNewIdpUser(ctx, accountID, inviterID, invite)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -117,18 +123,26 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
Issued: invite.Issued,
|
||||
IntegrationReference: invite.IntegrationReference,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Email: invite.Email,
|
||||
Name: invite.Name,
|
||||
}
|
||||
|
||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !IsEmbeddedIdp(am.idpManager) {
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||
eventType := activity.UserInvited
|
||||
if IsEmbeddedIdp(am.idpManager) {
|
||||
eventType = activity.UserCreated
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newUser.Id, accountID, eventType, nil)
|
||||
|
||||
return newUser.ToUserInfo(idpUser)
|
||||
}
|
||||
@@ -172,6 +186,34 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
|
||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email)
|
||||
}
|
||||
|
||||
// createEmbeddedIdpUser validates the invite and creates a new user in the embedded IdP.
|
||||
// Unlike createNewIdpUser, this method fetches user data directly from the database
|
||||
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
|
||||
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
|
||||
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get inviter user: %w", err)
|
||||
}
|
||||
|
||||
if inviter == nil {
|
||||
return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist", inviterID)
|
||||
}
|
||||
|
||||
// check if the user is already registered with this email => reject
|
||||
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range existingUsers {
|
||||
if strings.EqualFold(user.Email, invite.Email) {
|
||||
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
|
||||
}
|
||||
}
|
||||
|
||||
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviter.Email)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
||||
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
|
||||
}
|
||||
@@ -757,7 +799,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
|
||||
// If the AccountManager has a non-nil idpManager and the User is not a service user,
|
||||
// it will attempt to look up the UserData from the cache.
|
||||
func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) {
|
||||
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||
if !isNil(am.idpManager) && !user.IsServiceUser && !IsEmbeddedIdp(am.idpManager) {
|
||||
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -808,7 +850,10 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
|
||||
}
|
||||
|
||||
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) {
|
||||
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error) {
|
||||
userID := userAuth.UserId
|
||||
domain := userAuth.Domain
|
||||
|
||||
start := time.Now()
|
||||
unlock := am.Store.AcquireGlobalLock(ctx)
|
||||
defer unlock()
|
||||
@@ -819,7 +864,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
|
||||
account, err := am.Store.GetAccountByUser(ctx, userID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
account, err = am.newAccount(ctx, userID, lowerDomain)
|
||||
account, err = am.newAccount(ctx, userID, lowerDomain, userAuth.Email, userAuth.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -884,7 +929,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
var queriedUsers []*idp.UserData
|
||||
var err error
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
// embedded IdP ensures that we have user data (email and name) stored in the database.
|
||||
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
usersFromIntegration := make([]*idp.UserData, 0)
|
||||
for _, user := range accountUsers {
|
||||
@@ -921,6 +967,10 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Try to decode Dex user ID to extract the IdP ID (connector ID)
|
||||
if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" {
|
||||
info.IdPID = connectorID
|
||||
}
|
||||
userInfosMap[accountUser.Id] = info
|
||||
}
|
||||
|
||||
@@ -942,7 +992,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
|
||||
info = &types.UserInfo{
|
||||
ID: localUser.Id,
|
||||
Email: "",
|
||||
Email: localUser.Email,
|
||||
Name: name,
|
||||
Role: string(localUser.Role),
|
||||
AutoGroups: localUser.AutoGroups,
|
||||
@@ -951,6 +1001,10 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
|
||||
NonDeletable: localUser.NonDeletable,
|
||||
}
|
||||
}
|
||||
// Try to decode Dex user ID to extract the IdP ID (connector ID)
|
||||
if _, connectorID, decodeErr := dex.DecodeDexUserID(localUser.Id); decodeErr == nil && connectorID != "" {
|
||||
info.IdPID = connectorID
|
||||
}
|
||||
userInfosMap[info.ID] = info
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -29,6 +30,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
@@ -58,7 +60,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = s.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -105,7 +107,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: false,
|
||||
@@ -133,7 +135,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockTargetUserId] = &types.User{
|
||||
Id: mockTargetUserId,
|
||||
IsServiceUser: true,
|
||||
@@ -165,7 +167,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -190,7 +192,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -215,7 +217,7 @@ func TestUser_DeletePAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
@@ -258,7 +260,7 @@ func TestUser_GetPAT(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -298,7 +300,7 @@ func TestUser_GetAllPATs(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockUserID] = &types.User{
|
||||
Id: mockUserID,
|
||||
AccountID: mockAccountID,
|
||||
@@ -362,6 +364,8 @@ func TestUser_Copy(t *testing.T) {
|
||||
ID: 0,
|
||||
IntegrationType: "test",
|
||||
},
|
||||
Email: "whatever@gmail.com",
|
||||
Name: "John Doe",
|
||||
}
|
||||
|
||||
err := validateStruct(user)
|
||||
@@ -408,7 +412,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -455,7 +459,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -503,7 +507,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -534,7 +538,7 @@ func TestUser_InviteNewUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -641,7 +645,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockServiceUserID] = tt.serviceUser
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
@@ -680,7 +684,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -707,7 +711,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -801,7 +805,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &types.User{
|
||||
@@ -969,7 +973,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -1005,9 +1009,9 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users["normal_user1"] = types.NewRegularUser("normal_user1", "", "")
|
||||
account.Users["normal_user2"] = types.NewRegularUser("normal_user2", "", "")
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
@@ -1047,7 +1051,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
externalUser := &types.User{
|
||||
Id: "externalUser",
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1104,7 +1108,7 @@ func TestUser_IsAdmin(t *testing.T) {
|
||||
user := types.NewAdminUser(mockUserID)
|
||||
assert.True(t, user.HasAdminPower())
|
||||
|
||||
user = types.NewRegularUser(mockUserID)
|
||||
user = types.NewRegularUser(mockUserID, "", "")
|
||||
assert.False(t, user.HasAdminPower())
|
||||
}
|
||||
|
||||
@@ -1115,7 +1119,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1149,7 +1153,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||
account.Users[mockServiceUserID] = &types.User{
|
||||
Id: mockServiceUserID,
|
||||
Role: "user",
|
||||
@@ -1320,13 +1324,13 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
// create an account and an admin user
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io")
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: ownerUserID, Domain: "netbird.io"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create other users
|
||||
account.Users[regularUserID] = types.NewRegularUser(regularUserID)
|
||||
account.Users[regularUserID] = types.NewRegularUser(regularUserID, "", "")
|
||||
account.Users[adminUserID] = types.NewAdminUser(adminUserID)
|
||||
account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
@@ -1516,7 +1520,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
|
||||
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", "", "", false)
|
||||
targetId := "user2"
|
||||
account1.Users[targetId] = &types.User{
|
||||
Id: targetId,
|
||||
@@ -1525,7 +1529,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
|
||||
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", "", "", false)
|
||||
require.NoError(t, s.SaveAccount(context.Background(), account2))
|
||||
|
||||
permissionsManager := permissions.NewManager(s)
|
||||
@@ -1552,7 +1556,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
|
||||
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", "", "", false)
|
||||
account1.Settings.RegularUsersViewBlocked = false
|
||||
account1.Users["blocked-user"] = &types.User{
|
||||
Id: "blocked-user",
|
||||
@@ -1574,7 +1578,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, store.SaveAccount(context.Background(), account1))
|
||||
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
|
||||
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", "", "", false)
|
||||
account2.Users["settings-blocked-user"] = &types.User{
|
||||
Id: "settings-blocked-user",
|
||||
Role: types.UserRoleUser,
|
||||
@@ -1771,7 +1775,7 @@ func TestApproveUser(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create account with admin and pending approval user
|
||||
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
|
||||
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1782,7 +1786,7 @@ func TestApproveUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
@@ -1807,12 +1811,12 @@ func TestApproveUser(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "not pending approval")
|
||||
|
||||
// Test approval by non-admin should fail
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser := types.NewRegularUser("regular-user", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
pendingUser2 := types.NewRegularUser("pending-user-2")
|
||||
pendingUser2 := types.NewRegularUser("pending-user-2", "", "")
|
||||
pendingUser2.AccountID = account.Id
|
||||
pendingUser2.Blocked = true
|
||||
pendingUser2.PendingApproval = true
|
||||
@@ -1830,7 +1834,7 @@ func TestRejectUser(t *testing.T) {
|
||||
}
|
||||
|
||||
// Create account with admin and pending approval user
|
||||
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
|
||||
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", "", "", false)
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1841,7 +1845,7 @@ func TestRejectUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create user pending approval
|
||||
pendingUser := types.NewRegularUser("pending-user")
|
||||
pendingUser := types.NewRegularUser("pending-user", "", "")
|
||||
pendingUser.AccountID = account.Id
|
||||
pendingUser.Blocked = true
|
||||
pendingUser.PendingApproval = true
|
||||
@@ -1857,7 +1861,7 @@ func TestRejectUser(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
|
||||
// Test rejection of non-pending user should fail
|
||||
regularUser := types.NewRegularUser("regular-user")
|
||||
regularUser := types.NewRegularUser("regular-user", "", "")
|
||||
regularUser.AccountID = account.Id
|
||||
err = manager.Store.SaveUser(context.Background(), regularUser)
|
||||
require.NoError(t, err)
|
||||
@@ -1867,7 +1871,7 @@ func TestRejectUser(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "not pending approval")
|
||||
|
||||
// Test rejection by non-admin should fail
|
||||
pendingUser2 := types.NewRegularUser("pending-user-2")
|
||||
pendingUser2 := types.NewRegularUser("pending-user-2", "", "")
|
||||
pendingUser2.AccountID = account.Id
|
||||
pendingUser2.Blocked = true
|
||||
pendingUser2.PendingApproval = true
|
||||
@@ -1877,3 +1881,149 @@ func TestRejectUser(t *testing.T) {
|
||||
err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create temporary directory for Dex
|
||||
tmpDir := t.TempDir()
|
||||
dexDataDir := tmpDir + "/dex"
|
||||
require.NoError(t, os.MkdirAll(dexDataDir, 0700))
|
||||
|
||||
// Create embedded IDP config
|
||||
embeddedIdPConfig := &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Issuer: "http://localhost:5556/dex",
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: dexDataDir + "/dex.db",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create embedded IDP manager
|
||||
embeddedIdp, err := idp.NewEmbeddedIdPManager(ctx, embeddedIdPConfig, nil)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = embeddedIdp.Stop(ctx) }()
|
||||
|
||||
// Create test store
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", tmpDir)
|
||||
require.NoError(t, err)
|
||||
defer cleanup()
|
||||
|
||||
// Create account with owner user
|
||||
account := newAccountWithId(ctx, mockAccountID, mockUserID, "", "owner@test.com", "Owner User", false)
|
||||
require.NoError(t, testStore.SaveAccount(ctx, account))
|
||||
|
||||
// Create mock network map controller
|
||||
ctrl := gomock.NewController(t)
|
||||
networkMapControllerMock := network_map.NewMockController(ctrl)
|
||||
networkMapControllerMock.EXPECT().
|
||||
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
// Create account manager with embedded IDP
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
am := DefaultAccountManager{
|
||||
Store: testStore,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
permissionsManager: permissionsManager,
|
||||
idpManager: embeddedIdp,
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
networkMapController: networkMapControllerMock,
|
||||
}
|
||||
|
||||
// Initialize cache manager
|
||||
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
|
||||
require.NoError(t, err)
|
||||
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
|
||||
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
|
||||
|
||||
t.Run("create regular user returns password", func(t *testing.T) {
|
||||
userInfo, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
|
||||
Email: "newuser@test.com",
|
||||
Name: "New User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
IsServiceUser: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userInfo)
|
||||
|
||||
// Verify user data
|
||||
assert.Equal(t, "newuser@test.com", userInfo.Email)
|
||||
assert.Equal(t, "New User", userInfo.Name)
|
||||
assert.Equal(t, "user", userInfo.Role)
|
||||
assert.NotEmpty(t, userInfo.ID)
|
||||
|
||||
// IMPORTANT: Password should be returned for embedded IDP
|
||||
assert.NotEmpty(t, userInfo.Password, "Password should be returned for embedded IDP user")
|
||||
t.Logf("Created user: ID=%s, Email=%s, Password=%s", userInfo.ID, userInfo.Email, userInfo.Password)
|
||||
|
||||
// Verify user ID is in Dex encoded format
|
||||
rawUserID, connectorID, err := dex.DecodeDexUserID(userInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, rawUserID)
|
||||
assert.Equal(t, "local", connectorID)
|
||||
t.Logf("Decoded user ID: rawUserID=%s, connectorID=%s", rawUserID, connectorID)
|
||||
|
||||
// Verify user exists in database with correct data
|
||||
dbUser, err := testStore.GetUserByUserID(ctx, store.LockingStrengthNone, userInfo.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newuser@test.com", dbUser.Email)
|
||||
assert.Equal(t, "New User", dbUser.Name)
|
||||
|
||||
// Store user ID for delete test
|
||||
createdUserID := userInfo.ID
|
||||
|
||||
t.Run("delete user works", func(t *testing.T) {
|
||||
err := am.DeleteUser(ctx, mockAccountID, mockUserID, createdUserID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify user is deleted from database
|
||||
_, err = testStore.GetUserByUserID(ctx, store.LockingStrengthNone, createdUserID)
|
||||
assert.Error(t, err, "User should be deleted from database")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("create service user does not return password", func(t *testing.T) {
|
||||
userInfo, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
|
||||
Name: "Service User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
IsServiceUser: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, userInfo)
|
||||
|
||||
assert.True(t, userInfo.IsServiceUser)
|
||||
assert.Equal(t, "Service User", userInfo.Name)
|
||||
// Service users don't have passwords
|
||||
assert.Empty(t, userInfo.Password, "Service users should not have passwords")
|
||||
})
|
||||
|
||||
t.Run("duplicate email fails", func(t *testing.T) {
|
||||
// Create first user
|
||||
_, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
|
||||
Email: "duplicate@test.com",
|
||||
Name: "First User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
IsServiceUser: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to create second user with same email
|
||||
_, err = am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
|
||||
Email: "duplicate@test.com",
|
||||
Name: "Second User",
|
||||
Role: "user",
|
||||
AutoGroups: []string{},
|
||||
IsServiceUser: false,
|
||||
})
|
||||
assert.Error(t, err, "Creating user with duplicate email should fail")
|
||||
t.Logf("Duplicate email error: %v", err)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user