[management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)

Embed Dex as a built-in IdP to simplify self-hosting setup.
Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management.

more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
This commit is contained in:
Misha Bragin
2026-01-07 08:52:32 -05:00
committed by GitHub
parent 5393ad948f
commit e586c20e36
90 changed files with 7702 additions and 517 deletions

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
)
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
@@ -135,76 +136,208 @@ var (
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
_, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig)
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
applyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(loadedConfig)
if err != nil {
return nil, err
}
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
return nil, err
}
logConfigInfo(loadedConfig)
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
return loadedConfig, nil
}
// applyCommandLineOverrides applies command-line flag overrides to the config
func applyCommandLineOverrides(cfg *nbconfig.Config) {
if mgmtLetsencryptDomain != "" {
loadedConfig.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
}
if mgmtDataDir != "" {
loadedConfig.Datadir = mgmtDataDir
cfg.Datadir = mgmtDataDir
}
if certKey != "" && certFile != "" {
loadedConfig.HttpConfig.CertFile = certFile
loadedConfig.HttpConfig.CertKey = certKey
cfg.HttpConfig.CertFile = certFile
cfg.HttpConfig.CertKey = certKey
}
}
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
oidcEndpoint := loadedConfig.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint != "" {
// if OIDCConfigEndpoint is specified, we can load DeviceAuthEndpoint and TokenEndpoint automatically
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
if err != nil {
return nil, err
}
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
// apply some defaults based on the EmbeddedIdP config
if disableSingleAccMode {
// Embedded IdP requires single account mode - multiple account mode is not supported
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP. Please remove --disable-single-account-mode flag")
}
// Enable user deletion from IDP by default if EmbeddedIdP is enabled
userDeleteFromIDPEnabled = true
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, loadedConfig.HttpConfig.AuthIssuer)
loadedConfig.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, loadedConfig.HttpConfig.AuthKeysLocation)
loadedConfig.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if !(loadedConfig.DeviceAuthorizationFlow == nil || strings.ToLower(loadedConfig.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE)) {
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
u, err := url.Parse(oidcEndpoint)
if err != nil {
return nil, err
}
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain)
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
loadedConfig.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
}
}
if loadedConfig.PKCEAuthorizationFlow != nil {
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
loadedConfig.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
if loadedConfig.Relay != nil {
log.Infof("Relay addresses: %v", loadedConfig.Relay.Addresses)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
return loadedConfig, err
issuer := cfg.EmbeddedIdP.Issuer
// Set AuthIssuer from EmbeddedIdP issuer
if cfg.HttpConfig.AuthIssuer == "" {
cfg.HttpConfig.AuthIssuer = issuer
}
// Set AuthAudience to the dashboard client ID
if cfg.HttpConfig.AuthAudience == "" {
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
}
// Set AuthUserIDClaim to "sub" (standard OIDC claim)
if cfg.HttpConfig.AuthUserIDClaim == "" {
cfg.HttpConfig.AuthUserIDClaim = "sub"
}
// Set AuthKeysLocation to the JWKS endpoint
if cfg.HttpConfig.AuthKeysLocation == "" {
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
}
// Set OIDCConfigEndpoint to the discovery endpoint
if cfg.HttpConfig.OIDCConfigEndpoint == "" {
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
}
// Copy SignKeyRefreshEnabled from EmbeddedIdP config
if cfg.EmbeddedIdP.SignKeyRefreshEnabled {
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
}
return nil
}
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" || cfg.EmbeddedIdP != nil {
return nil
}
log.WithContext(ctx).Infof("loading OIDC configuration from the provided IDP configuration endpoint %s", oidcEndpoint)
oidcConfig, err := fetchOIDCConfig(ctx, oidcEndpoint)
if err != nil {
return err
}
log.WithContext(ctx).Infof("loaded OIDC configuration from the provided IDP configuration endpoint: %s", oidcEndpoint)
log.WithContext(ctx).Infof("overriding HttpConfig.AuthIssuer with a new value %s, previously configured value: %s",
oidcConfig.Issuer, cfg.HttpConfig.AuthIssuer)
cfg.HttpConfig.AuthIssuer = oidcConfig.Issuer
log.WithContext(ctx).Infof("overriding HttpConfig.AuthKeysLocation (JWT certs) with a new value %s, previously configured value: %s",
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
return err
}
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
return nil
}
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
return nil
}
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, cfg.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint)
cfg.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.DeviceAuthEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.DeviceAuthEndpoint, cfg.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint)
cfg.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint = oidcConfig.DeviceAuthEndpoint
u, err := url.Parse(oidcEndpoint)
if err != nil {
return err
}
log.WithContext(ctx).Infof("overriding DeviceAuthorizationFlow.ProviderConfig.Domain with a new value: %s, previously configured value: %s",
u.Host, cfg.DeviceAuthorizationFlow.ProviderConfig.Domain)
cfg.DeviceAuthorizationFlow.ProviderConfig.Domain = u.Host
if cfg.DeviceAuthorizationFlow.ProviderConfig.Scope == "" {
cfg.DeviceAuthorizationFlow.ProviderConfig.Scope = nbconfig.DefaultDeviceAuthFlowScope
}
return nil
}
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
if cfg.PKCEAuthorizationFlow == nil {
return
}
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.TokenEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.TokenEndpoint, cfg.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint)
cfg.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint = oidcConfig.TokenEndpoint
log.WithContext(ctx).Infof("overriding PKCEAuthorizationFlow.AuthorizationEndpoint with a new value: %s, previously configured value: %s",
oidcConfig.AuthorizationEndpoint, cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint)
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// logConfigInfo logs informational messages about the loaded configuration
func logConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
if cfg.Relay != nil {
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
}
}
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
key, err := crypt.GenerateKey()
if err != nil {
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
}
cfg.DataStoreEncryptionKey = key
if err := util.DirectWriteJson(ctx, configPath, cfg); err != nil {
return fmt.Errorf("failed to save config with new encryption key: %v", err)
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey generated and saved to config")
return nil
}
// OIDCConfigResponse used for parsing OIDC config response

View File

@@ -309,7 +309,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
owner := types.NewOwnerUser(userID)
owner := types.NewOwnerUser(userID, "", "")
owner.AccountID = accountID
users[userID] = owner

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
nbContext "github.com/netbirdio/netbird/management/server/context"
@@ -29,6 +28,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/crypt"
)
var (
@@ -62,6 +62,14 @@ func (s *BaseServer) Store() store.Store {
log.Fatalf("failed to create store: %v", err)
}
if s.Config.DataStoreEncryptionKey != "" {
fieldEncrypt, err := crypt.NewFieldEncrypt(s.Config.DataStoreEncryptionKey)
if err != nil {
log.Fatalf("failed to create field encryptor: %v", err)
}
store.SetFieldEncrypt(fieldEncrypt)
}
return store
})
}
@@ -73,27 +81,18 @@ func (s *BaseServer) EventStore() activity.Store {
log.Fatalf("failed to initialize integration metrics: %v", err)
}
eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
if err != nil {
log.Fatalf("failed to initialize event store: %v", err)
}
if s.Config.DataStoreEncryptionKey != key {
log.WithContext(context.Background()).Infof("update Config with activity store key")
s.Config.DataStoreEncryptionKey = key
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
if err != nil {
log.Fatalf("failed to update Config with activity store: %v", err)
}
}
return eventStore
})
}
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -145,7 +144,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
}
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}

View File

@@ -57,6 +57,10 @@ type Config struct {
// disable default all-to-all policy
DisableDefaultPolicy bool
// EmbeddedIdP contains configuration for the embedded Dex OIDC provider.
// When set, Dex will be embedded in the management server and serve requests at /oauth2/
EmbeddedIdP *idp.EmbeddedIdPConfig
}
// GetAuthAudiences returns the audience from the http config and device authorization flow config

View File

@@ -44,6 +44,9 @@ func maybeCreateNamed[T any](s Server, name string, createFunc func() T) (result
func maybeCreateKeyed[T any](s Server, key string, createFunc func() T) (result T, isNew bool) {
if t, ok := s.GetContainer(key); ok {
if t == nil {
return result, false
}
return t.(T), false
}

View File

@@ -55,14 +55,33 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
}
func (s *BaseServer) AuthManager() auth.Manager {
audiences := s.Config.GetAuthAudiences()
audience := s.Config.HttpConfig.AuthAudience
keysLocation := s.Config.HttpConfig.AuthKeysLocation
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
issuer := s.Config.HttpConfig.AuthIssuer
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
// Use embedded IdP configuration if available
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
audiences = oauthProvider.GetClientIDs()
if len(audiences) > 0 {
audience = audiences[0] // Use the first client ID as the primary audience
}
keysLocation = oauthProvider.GetKeysLocation()
signingKeyRefreshEnabled = true
issuer = oauthProvider.GetIssuer()
userIDClaim = oauthProvider.GetUserIDClaim()
}
return Create(s, func() auth.Manager {
return auth.NewManager(s.Store(),
s.Config.HttpConfig.AuthIssuer,
s.Config.HttpConfig.AuthAudience,
s.Config.HttpConfig.AuthKeysLocation,
s.Config.HttpConfig.AuthUserIDClaim,
s.Config.GetAuthAudiences(),
s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
issuer,
audience,
keysLocation,
userIDClaim,
audiences,
signingKeyRefreshEnabled)
})
}

View File

@@ -95,6 +95,17 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP manager if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err)
}
return idpManager
}
// Fall back to external IdP manager
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
@@ -105,6 +116,25 @@ func (s *BaseServer) IdpManager() idp.Manager {
})
}
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil
}
idpManager := s.IdpManager()
if idpManager == nil {
return nil
}
// Reuse the EmbeddedIdPManager instance from IdpManager
// EmbeddedIdPManager implements both idp.Manager and idp.OAuthConfigProvider
if provider, ok := idpManager.(idp.OAuthConfigProvider); ok {
return provider
}
return nil
}
func (s *BaseServer) GroupsManager() groups.Manager {
return Create(s, func() groups.Manager {
return groups.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
@@ -22,7 +23,6 @@ import (
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
"github.com/netbirdio/netbird/version"
@@ -40,7 +40,7 @@ type Server interface {
SetContainer(key string, container any)
}
// Server holds the HTTP BaseServer instance.
// BaseServer holds the HTTP server instance.
// Add any additional fields you need, such as database connections, Config, etc.
type BaseServer struct {
// Config holds the server configuration
@@ -144,7 +144,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
}
rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
switch {
case s.certManager != nil:
// a call to certManager.Listener() always creates a new listener so we do it once
@@ -215,6 +215,10 @@ func (s *BaseServer) Stop() error {
if s.update != nil {
s.update.StopWatch()
}
// Stop embedded IdP if configured
if embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager); ok {
_ = embeddedIdP.Stop(ctx)
}
select {
case <-s.Errors():
@@ -246,11 +250,7 @@ func (s *BaseServer) SetContainer(key string, container any) {
log.Tracef("container with key %s set successfully", key)
}
func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) error {
return util.DirectWriteJson(ctx, path, config)
}
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

View File

@@ -16,6 +16,7 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/shared/management/client/common"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
@@ -24,6 +25,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
"github.com/netbirdio/netbird/management/server/store"
@@ -69,6 +71,8 @@ type Server struct {
networkMapController network_map.Controller
oAuthConfigProvider idp.OAuthConfigProvider
syncSem atomic.Int32
syncLim int32
}
@@ -83,6 +87,7 @@ func NewServer(
authManager auth.Manager,
integratedPeerValidator integrated_validator.IntegratedValidator,
networkMapController network_map.Controller,
oAuthConfigProvider idp.OAuthConfigProvider,
) (*Server, error) {
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
@@ -119,6 +124,7 @@ func NewServer(
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
networkMapController: networkMapController,
oAuthConfigProvider: oAuthConfigProvider,
loginFilter: newLoginFilter(),
@@ -761,32 +767,48 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}
var flowInfoResp *proto.DeviceAuthorizationFlow
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
}
// Use embedded IdP configuration if available
if s.oAuthConfigProvider != nil {
flowInfoResp = &proto.DeviceAuthorizationFlow{
Provider: proto.DeviceAuthorizationFlow_HOSTED,
ProviderConfig: &proto.ProviderConfig{
ClientID: s.oAuthConfigProvider.GetCLIClientID(),
Audience: s.oAuthConfigProvider.GetCLIClientID(),
DeviceAuthEndpoint: s.oAuthConfigProvider.GetDeviceAuthEndpoint(),
TokenEndpoint: s.oAuthConfigProvider.GetTokenEndpoint(),
Scope: s.oAuthConfigProvider.GetDefaultScopes(),
},
}
} else {
if s.config.DeviceAuthorizationFlow == nil || s.config.DeviceAuthorizationFlow.Provider == string(nbconfig.NONE) {
return nil, status.Error(codes.NotFound, "no device authorization flow information available")
}
flowInfoResp := &proto.DeviceAuthorizationFlow{
Provider: proto.DeviceAuthorizationFlowProvider(provider),
ProviderConfig: &proto.ProviderConfig{
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
},
provider, ok := proto.DeviceAuthorizationFlowProvider_value[strings.ToUpper(s.config.DeviceAuthorizationFlow.Provider)]
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "no provider found in the protocol for %s", s.config.DeviceAuthorizationFlow.Provider)
}
flowInfoResp = &proto.DeviceAuthorizationFlow{
Provider: proto.DeviceAuthorizationFlowProvider(provider),
ProviderConfig: &proto.ProviderConfig{
ClientID: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.DeviceAuthorizationFlow.ProviderConfig.ClientSecret,
Domain: s.config.DeviceAuthorizationFlow.ProviderConfig.Domain,
Audience: s.config.DeviceAuthorizationFlow.ProviderConfig.Audience,
DeviceAuthEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.DeviceAuthEndpoint,
TokenEndpoint: s.config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint,
Scope: s.config.DeviceAuthorizationFlow.ProviderConfig.Scope,
UseIDToken: s.config.DeviceAuthorizationFlow.ProviderConfig.UseIDToken,
},
}
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
return nil, status.Error(codes.Internal, "failed to encrypt device authorization flow information")
}
return &proto.EncryptedMessage{
@@ -820,30 +842,47 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
return nil, status.Error(codes.InvalidArgument, errMSG)
}
if s.config.PKCEAuthorizationFlow == nil {
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
}
var initInfoFlow *proto.PKCEAuthorizationFlow
initInfoFlow := &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag),
},
// Use embedded IdP configuration if available
if s.oAuthConfigProvider != nil {
initInfoFlow = &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{
Audience: s.oAuthConfigProvider.GetCLIClientID(),
ClientID: s.oAuthConfigProvider.GetCLIClientID(),
TokenEndpoint: s.oAuthConfigProvider.GetTokenEndpoint(),
AuthorizationEndpoint: s.oAuthConfigProvider.GetAuthorizationEndpoint(),
Scope: s.oAuthConfigProvider.GetDefaultScopes(),
RedirectURLs: s.oAuthConfigProvider.GetCLIRedirectURLs(),
LoginFlag: uint32(common.LoginFlagPromptLogin),
},
}
} else {
if s.config.PKCEAuthorizationFlow == nil {
return nil, status.Error(codes.NotFound, "no pkce authorization flow information available")
}
initInfoFlow = &proto.PKCEAuthorizationFlow{
ProviderConfig: &proto.ProviderConfig{
Audience: s.config.PKCEAuthorizationFlow.ProviderConfig.Audience,
ClientID: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: s.config.PKCEAuthorizationFlow.ProviderConfig.ClientSecret,
TokenEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.TokenEndpoint,
AuthorizationEndpoint: s.config.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint,
Scope: s.config.PKCEAuthorizationFlow.ProviderConfig.Scope,
RedirectURLs: s.config.PKCEAuthorizationFlow.ProviderConfig.RedirectURLs,
UseIDToken: s.config.PKCEAuthorizationFlow.ProviderConfig.UseIDToken,
DisablePromptLogin: s.config.PKCEAuthorizationFlow.ProviderConfig.DisablePromptLogin,
LoginFlag: uint32(s.config.PKCEAuthorizationFlow.ProviderConfig.LoginFlag),
},
}
}
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
if err != nil {
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
return nil, status.Error(codes.Internal, "failed to encrypt pkce authorization flow information")
}
return &proto.EncryptedMessage{

View File

@@ -243,7 +243,7 @@ func BuildManager(
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
if !isNil(am.idpManager) {
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
go func() {
err := am.warmupIDPCache(ctx, cacheStore)
if err != nil {
@@ -557,7 +557,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co
// newAccount creates a new Account with a generated ID and generated default setup keys.
// If ID is already in use (due to collision) we try one more time before returning error
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) {
func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain, email, name string) (*types.Account, error) {
for i := 0; i < 2; i++ {
accountId := xid.New().String()
@@ -568,7 +568,7 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain
log.WithContext(ctx).Warnf("an account with ID already exists, retrying...")
continue
case statusErr.Type() == status.NotFound:
newAccount := newAccountWithId(ctx, accountId, userID, domain, am.disableDefaultPolicy)
newAccount := newAccountWithId(ctx, accountId, userID, domain, email, name, am.disableDefaultPolicy)
am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil)
return newAccount, nil
default:
@@ -741,23 +741,23 @@ func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID st
// If user does have an account, it returns the user's account ID.
// If the user doesn't have an account, it creates one using the provided domain.
// Returns the account ID or an error if none is found or created.
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) {
if userID == "" {
func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) {
if userAuth.UserId == "" {
return "", status.Errorf(status.NotFound, "no valid userID provided")
}
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
acc, err := am.GetOrCreateAccountByUser(ctx, userAuth)
if err != nil {
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userAuth.UserId)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
if err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, acc.Id); err != nil {
return "", err
}
return account.Id, nil
return acc.Id, nil
}
return "", err
}
@@ -768,9 +768,19 @@ func isNil(i idp.Manager) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}
// IsEmbeddedIdp checks if the IDP manager is an embedded IDP (data stored locally in DB).
// When true, user cache should be skipped and data fetched directly from the IDP manager.
func IsEmbeddedIdp(i idp.Manager) bool {
if isNil(i) {
return false
}
_, ok := i.(*idp.EmbeddedIdPManager)
return ok
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
@@ -1016,6 +1026,9 @@ func (am *DefaultAccountManager) isCacheFresh(ctx context.Context, accountUsers
}
func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accountID, userID string) error {
if IsEmbeddedIdp(am.idpManager) {
return nil
}
data, err := am.getAccountFromCache(ctx, accountID, false)
if err != nil {
return err
@@ -1107,7 +1120,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
lowerDomain := strings.ToLower(userAuth.Domain)
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain)
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain, userAuth.Email, userAuth.Name)
if err != nil {
return "", err
}
@@ -1132,7 +1145,7 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
}
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth auth.UserAuth) (string, error) {
newUser := types.NewRegularUser(userAuth.UserId)
newUser := types.NewRegularUser(userAuth.UserId, userAuth.Email, userAuth.Name)
newUser.AccountID = domainAccountID
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, domainAccountID)
@@ -1315,6 +1328,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil {
// this is not really possible because we got an account by user ID
log.Errorf("failed to get user by ID %s: %v", userAuth.UserId, err)
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
}
@@ -1512,7 +1526,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
}
if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
return am.GetAccountIDByUserID(ctx, userAuth)
}
if userAuth.AccountId != "" {
@@ -1734,7 +1748,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
}
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
func newAccountWithId(ctx context.Context, accountID, userID, domain, email, name string, disableDefaultPolicy bool) *types.Account {
log.WithContext(ctx).Debugf("creating new account")
network := types.NewNetwork()
@@ -1744,7 +1758,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
owner := types.NewOwnerUser(userID)
owner := types.NewOwnerUser(userID, email, name)
owner.AccountID = accountID
users[userID] = owner

View File

@@ -24,7 +24,7 @@ import (
type ExternalCacheManager nbcache.UserDataCache
type Manager interface {
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error)
GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error)
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
@@ -44,7 +44,7 @@ type Manager interface {
GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error)
GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error)
AccountExists(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error)
GetAccountIDFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
GetUserByID(ctx context.Context, id string) (*types.User, error)
@@ -124,4 +124,9 @@ type Manager interface {
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error)
GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error)
GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error)
CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error
}

View File

@@ -382,7 +382,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
for _, testCase := range tt {
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", false)
account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io", "", "", false)
account.UpdateSettings(&testCase.accountSettings)
account.Network = network
account.Peers = testCase.peers
@@ -407,7 +407,7 @@ func TestNewAccount(t *testing.T) {
domain := "netbird.io"
userId := "account_creator"
accountID := "account_id"
account := newAccountWithId(context.Background(), accountID, userId, domain, false)
account := newAccountWithId(context.Background(), accountID, userId, domain, "", "", false)
verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId})
}
@@ -418,7 +418,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID, Domain: ""})
if err != nil {
t.Fatal(err)
}
@@ -612,7 +612,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain)
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain})
require.NoError(t, err, "create init user failed")
initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -649,10 +649,10 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
userId := "user-id"
domain := "test.domain"
_ = newAccountWithId(context.Background(), "", userId, domain, false)
_ = newAccountWithId(context.Background(), "", userId, domain, "", "", false)
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain)
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userId, Domain: domain})
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
// that happens inside the GetAccountIDByUserID where the id is getting generated
@@ -718,7 +718,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) {
}
userId := "test_user"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: ""})
if err != nil {
t.Fatal(err)
}
@@ -745,7 +745,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
userId := "test_user"
domain := "hotmail.com"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: domain})
if err != nil {
t.Fatal(err)
}
@@ -759,7 +759,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) {
domain = "gmail.com"
account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain)
account, err = manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userId, Domain: domain})
if err != nil {
t.Fatalf("got the following error while retrieving existing acc: %v", err)
}
@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
userId := "test_user"
accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userId, Domain: ""})
if err != nil {
t.Fatal(err)
}
@@ -795,14 +795,14 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid")
_, err = manager.GetAccountIDByUserID(context.Background(), "", "")
_, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: "", Domain: ""})
if err == nil {
t.Errorf("expected an error when user ID is empty")
}
}
func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) {
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err
@@ -1098,7 +1098,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
return
}
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID, Domain: "netbird.cloud"})
if err != nil {
t.Fatal(err)
}
@@ -1849,7 +1849,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
@@ -1864,7 +1864,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
_, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1876,7 +1876,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
}, false)
require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
@@ -1920,7 +1920,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1946,7 +1946,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
accountID, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
@@ -1963,7 +1963,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
_, err = manager.GetAccountIDByUserID(context.Background(), userID, "")
_, err = manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1975,7 +1975,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}, false)
require.NoError(t, err, "unable to add peer")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
account, err := manager.Store.GetAccount(context.Background(), accountID)
@@ -2025,7 +2025,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -3434,7 +3434,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
assert.True(t, cold)
})
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
t.Run("should return true when account is not found in cache", func(t *testing.T) {
@@ -3462,7 +3462,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
initiatorId := "test-user"
domain := "example.com"
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
account, err := manager.GetOrCreateAccountByUser(ctx, auth.UserAuth{UserId: initiatorId, Domain: domain})
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
@@ -3575,7 +3575,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
@@ -3607,7 +3607,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err)
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
onboarding := &types.AccountOnboarding{
@@ -3646,7 +3646,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key1, err := wgtypes.GenerateKey()
@@ -3717,7 +3717,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) {
// Create a domain-based account with user approval enabled
existingAccountID := "existing-account"
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", false)
account := newAccountWithId(context.Background(), existingAccountID, "owner-user", "example.com", "", "", false)
account.Settings.Extra = &types.ExtraSettings{
UserApprovalRequired: true,
}

View File

@@ -183,6 +183,10 @@ const (
AccountAutoUpdateVersionUpdated Activity = 92
IdentityProviderCreated Activity = 93
IdentityProviderUpdated Activity = 94
IdentityProviderDeleted Activity = 95
AccountDeleted Activity = 99999
)
@@ -295,6 +299,10 @@ var activityMap = map[Activity]Code{
UserCreated: {"User created", "user.create"},
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -49,8 +49,7 @@ func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim s
)
return &manager{
store: store,
store: store,
validator: jwtValidator,
extractor: claimsExtractor,
}

View File

@@ -277,7 +277,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
domain := "example.com"
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, false)
account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain, "", "", false)
account.Users[dnsRegularUserID] = &types.User{
Id: dnsRegularUserID,

View File

@@ -379,7 +379,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
Id: "example user",
AutoGroups: []string{groupForUsers.ID},
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
account.Routes[routeResource.ID] = routeResource
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/gorilla/mux"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
@@ -29,6 +30,8 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/dns"
"github.com/netbirdio/netbird/management/server/http/handlers/events"
"github.com/netbirdio/netbird/management/server/http/handlers/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/idp"
"github.com/netbirdio/netbird/management/server/http/handlers/instance"
"github.com/netbirdio/netbird/management/server/http/handlers/networks"
"github.com/netbirdio/netbird/management/server/http/handlers/peers"
"github.com/netbirdio/netbird/management/server/http/handlers/policies"
@@ -36,6 +39,8 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -51,23 +56,15 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
ctx context.Context,
accountManager account.Manager,
networksManager nbnetworks.Manager,
resourceManager resources.Manager,
routerManager routers.Manager,
groupsManager nbgroups.Manager,
LocationManager geolocation.Geolocation,
authManager auth.Manager,
appMetrics telemetry.AppMetrics,
integratedValidator integrated_validator.IntegratedValidator,
proxyController port_forwarding.Controller,
permissionsManager permissions.Manager,
peersManager nbpeers.Manager,
settingsManager settings.Manager,
networkMapController network_map.Controller,
) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if err := bypass.AddBypassPath("/api/setup"); err != nil {
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
@@ -122,7 +119,14 @@ func NewAPIHandler(
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, router)
// Check if embedded IdP is enabled
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
users.AddEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
@@ -134,6 +138,13 @@ func NewAPIHandler(
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
// Mount embedded IdP handler at /oauth2 path if configured
if embeddedIdpEnabled {
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
}
return rootRouter, nil
}

View File

@@ -36,22 +36,24 @@ const (
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager account.Manager
settingsManager settings.Manager
accountManager account.Manager
settingsManager settings.Manager
embeddedIdpEnabled bool
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager)
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
return &handler{
accountManager: accountManager,
settingsManager: settingsManager,
accountManager: accountManager,
settingsManager: settingsManager,
embeddedIdpEnabled: embeddedIdpEnabled,
}
}
@@ -163,7 +165,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -290,7 +292,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
util.WriteJSONObject(r.Context(), w, &resp)
}
@@ -319,7 +321,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
@@ -339,6 +341,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
EmbeddedIdpEnabled: &embeddedIdpEnabled,
}
if settings.NetworkRange.IsValid() {

View File

@@ -33,6 +33,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
AnyTimes()
return &handler{
embeddedIdpEnabled: false,
accountManager: &mock_server.MockAccountManager{
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
@@ -122,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: true,
expectedID: accountID,
@@ -145,6 +147,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -168,6 +171,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr("latest"),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -191,6 +195,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -214,6 +219,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -237,6 +243,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
},
expectedArray: false,
expectedID: accountID,

View File

@@ -0,0 +1,196 @@
package idp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// handler handles identity provider HTTP endpoints
type handler struct {
accountManager account.Manager
}
// AddEndpoints registers identity provider endpoints
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
h := newHandler(accountManager)
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
}
func newHandler(accountManager account.Manager) *handler {
return &handler{
accountManager: accountManager,
}
}
// getAllIdentityProviders returns all identity providers for the account
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
response := make([]api.IdentityProvider, 0, len(providers))
for _, p := range providers {
response = append(response, toAPIResponse(p))
}
util.WriteJSONObject(r.Context(), w, response)
}
// getIdentityProvider returns a specific identity provider
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
return
}
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(provider))
}
// createIdentityProvider creates a new identity provider
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
idp := fromAPIRequest(&req)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(created))
}
// updateIdentityProvider updates an existing identity provider
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
return
}
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
idp := fromAPIRequest(&req)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toAPIResponse(updated))
}
// deleteIdentityProvider deletes an identity provider
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "identity provider ID is required"), w)
return
}
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAPIResponse(idp *types.IdentityProvider) api.IdentityProvider {
resp := api.IdentityProvider{
Type: api.IdentityProviderType(idp.Type),
Name: idp.Name,
Issuer: idp.Issuer,
ClientId: idp.ClientID,
}
if idp.ID != "" {
resp.Id = &idp.ID
}
// Note: ClientSecret is never returned in responses for security
return resp
}
func fromAPIRequest(req *api.IdentityProviderRequest) *types.IdentityProvider {
return &types.IdentityProvider{
Type: types.IdentityProviderType(req.Type),
Name: req.Name,
Issuer: req.Issuer,
ClientID: req.ClientId,
ClientSecret: req.ClientSecret,
}
}

View File

@@ -0,0 +1,438 @@
package idp
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
existingIDPID = "existing-idp-id"
newIDPID = "new-idp-id"
)
func initIDPTestData(existingIDP *types.IdentityProvider) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetIdentityProvidersFunc: func(_ context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if existingIDP != nil {
return []*types.IdentityProvider{existingIDP}, nil
}
return []*types.IdentityProvider{}, nil
},
GetIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if existingIDP != nil && idpID == existingIDP.ID {
return existingIDP, nil
}
return nil, status.Errorf(status.NotFound, "identity provider not found")
},
CreateIdentityProviderFunc: func(_ context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if idp.Name == "" {
return nil, status.Errorf(status.InvalidArgument, "name is required")
}
created := idp.Copy()
created.ID = newIDPID
created.AccountID = accountID
return created, nil
},
UpdateIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
if accountID != testAccountID {
return nil, status.Errorf(status.NotFound, "account not found")
}
if existingIDP == nil || idpID != existingIDP.ID {
return nil, status.Errorf(status.NotFound, "identity provider not found")
}
updated := idp.Copy()
updated.ID = idpID
updated.AccountID = accountID
return updated, nil
},
DeleteIdentityProviderFunc: func(_ context.Context, accountID, idpID, userID string) error {
if accountID != testAccountID {
return status.Errorf(status.NotFound, "account not found")
}
if existingIDP == nil || idpID != existingIDP.ID {
return status.Errorf(status.NotFound, "identity provider not found")
}
return nil
},
},
}
}
func TestGetAllIdentityProviders(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
}
tt := []struct {
name string
expectedStatus int
expectedCount int
}{
{
name: "Get All Identity Providers",
expectedStatus: http.StatusOK,
expectedCount: 1,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/identity-providers", nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idps []api.IdentityProvider
err = json.Unmarshal(content, &idps)
require.NoError(t, err)
assert.Len(t, idps, tc.expectedCount)
})
}
}
func TestGetIdentityProvider(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
}
tt := []struct {
name string
idpID string
expectedStatus int
expectedBody bool
}{
{
name: "Get Existing Identity Provider",
idpID: existingIDPID,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Get Non-Existing Identity Provider",
idpID: "non-existing-id",
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
if tc.expectedBody {
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idp api.IdentityProvider
err = json.Unmarshal(content, &idp)
require.NoError(t, err)
assert.Equal(t, existingIDPID, *idp.Id)
assert.Equal(t, existingIDP.Name, idp.Name)
}
})
}
}
func TestCreateIdentityProvider(t *testing.T) {
tt := []struct {
name string
requestBody string
expectedStatus int
expectedBody bool
}{
{
name: "Create Identity Provider",
requestBody: `{
"name": "New IDP",
"type": "oidc",
"issuer": "https://new-issuer.example.com",
"client_id": "new-client-id",
"client_secret": "new-client-secret"
}`,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Create Identity Provider with Invalid JSON",
requestBody: `{invalid json`,
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
}
h := initIDPTestData(nil)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/identity-providers", bytes.NewBufferString(tc.requestBody))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
if tc.expectedBody {
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idp api.IdentityProvider
err = json.Unmarshal(content, &idp)
require.NoError(t, err)
assert.Equal(t, newIDPID, *idp.Id)
assert.Equal(t, "New IDP", idp.Name)
assert.Equal(t, api.IdentityProviderTypeOidc, idp.Type)
}
})
}
}
func TestUpdateIdentityProvider(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
ClientSecret: "client-secret",
}
tt := []struct {
name string
idpID string
requestBody string
expectedStatus int
expectedBody bool
}{
{
name: "Update Existing Identity Provider",
idpID: existingIDPID,
requestBody: `{
"name": "Updated IDP",
"type": "oidc",
"issuer": "https://updated-issuer.example.com",
"client_id": "updated-client-id"
}`,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Update Non-Existing Identity Provider",
idpID: "non-existing-id",
requestBody: `{
"name": "Updated IDP",
"type": "oidc",
"issuer": "https://updated-issuer.example.com",
"client_id": "updated-client-id"
}`,
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "Update Identity Provider with Invalid JSON",
idpID: existingIDPID,
requestBody: `{invalid json`,
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), bytes.NewBufferString(tc.requestBody))
req.Header.Set("Content-Type", "application/json")
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
if tc.expectedBody {
content, err := io.ReadAll(res.Body)
require.NoError(t, err)
var idp api.IdentityProvider
err = json.Unmarshal(content, &idp)
require.NoError(t, err)
assert.Equal(t, existingIDPID, *idp.Id)
assert.Equal(t, "Updated IDP", idp.Name)
}
})
}
}
func TestDeleteIdentityProvider(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
}
tt := []struct {
name string
idpID string
expectedStatus int
}{
{
name: "Delete Existing Identity Provider",
idpID: existingIDPID,
expectedStatus: http.StatusOK,
},
{
name: "Delete Non-Existing Identity Provider",
idpID: "non-existing-id",
expectedStatus: http.StatusNotFound,
},
}
h := initIDPTestData(existingIDP)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/api/identity-providers/%s", tc.idpID), nil)
req = nbcontext.SetUserAuthInRequest(req, auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
assert.Equal(t, tc.expectedStatus, recorder.Code)
})
}
}
func TestToAPIResponse(t *testing.T) {
idp := &types.IdentityProvider{
ID: "test-id",
Name: "Test IDP",
Type: types.IdentityProviderTypeGoogle,
Issuer: "https://accounts.google.com",
ClientID: "client-id",
ClientSecret: "should-not-be-returned",
}
response := toAPIResponse(idp)
assert.Equal(t, "test-id", *response.Id)
assert.Equal(t, "Test IDP", response.Name)
assert.Equal(t, api.IdentityProviderTypeGoogle, response.Type)
assert.Equal(t, "https://accounts.google.com", response.Issuer)
assert.Equal(t, "client-id", response.ClientId)
// Note: ClientSecret is not included in response type by design
}
func TestFromAPIRequest(t *testing.T) {
req := &api.IdentityProviderRequest{
Name: "New IDP",
Type: api.IdentityProviderTypeOkta,
Issuer: "https://dev-123456.okta.com",
ClientId: "okta-client-id",
ClientSecret: "okta-client-secret",
}
idp := fromAPIRequest(req)
assert.Equal(t, "New IDP", idp.Name)
assert.Equal(t, types.IdentityProviderTypeOkta, idp.Type)
assert.Equal(t, "https://dev-123456.okta.com", idp.Issuer)
assert.Equal(t, "okta-client-id", idp.ClientID)
assert.Equal(t, "okta-client-secret", idp.ClientSecret)
}

View File

@@ -0,0 +1,67 @@
package instance
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// handler handles the instance setup HTTP endpoints
type handler struct {
instanceManager nbinstance.Manager
}
// AddEndpoints registers the instance setup endpoints.
// These endpoints bypass authentication for initial setup.
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
router.HandleFunc("/setup", h.setup).Methods("POST", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
// This endpoint is unauthenticated.
func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
setupRequired, err := h.instanceManager.IsSetupRequired(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to check setup status: %v", err)
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
SetupRequired: setupRequired,
})
}
// setup creates the initial admin user for the instance.
// This endpoint is unauthenticated but only works when setup is required.
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
var req api.SetupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
UserId: userData.ID,
Email: userData.Email,
})
}

View File

@@ -0,0 +1,281 @@
package instance
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"net/mail"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
// mockInstanceManager implements instance.Manager for testing
type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
func (m *mockInstanceManager) IsSetupRequired(ctx context.Context) (bool, error) {
if m.isSetupRequiredFn != nil {
return m.isSetupRequiredFn(ctx)
}
return m.isSetupRequired, nil
}
func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createOwnerUserFn != nil {
return m.createOwnerUserFn(ctx, email, password, name)
}
// Default mock includes validation like the real manager
if !m.isSetupRequired {
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
}
if email == "" {
return nil, status.Errorf(status.InvalidArgument, "email is required")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, status.Errorf(status.InvalidArgument, "invalid email format")
}
if name == "" {
return nil, status.Errorf(status.InvalidArgument, "name is required")
}
if password == "" {
return nil, status.Errorf(status.InvalidArgument, "password is required")
}
if len(password) < 8 {
return nil, status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
}
return &idp.UserData{
ID: "test-user-id",
Email: email,
Name: name,
}, nil
}
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
router := mux.NewRouter()
AddEndpoints(manager, router)
return router
}
func TestGetInstanceStatus_SetupRequired(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceStatus
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.True(t, response.SetupRequired)
}
func TestGetInstanceStatus_SetupNotRequired(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: false}
router := setupTestRouter(manager)
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.InstanceStatus
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.False(t, response.SetupRequired)
}
func TestGetInstanceStatus_Error(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequiredFn: func(ctx context.Context) (bool, error) {
return false, errors.New("database error")
},
}
router := setupTestRouter(manager)
req := httptest.NewRequest(http.MethodGet, "/instance", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_Success(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
assert.Equal(t, "admin@example.com", email)
assert.Equal(t, "securepassword123", password)
assert.Equal(t, "Admin User", name)
return &idp.UserData{
ID: "created-user-id",
Email: email,
Name: name,
}, nil
},
}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
err := json.NewDecoder(rec.Body).Decode(&response)
require.NoError(t, err)
assert.Equal(t, "created-user-id", response.UserId)
assert.Equal(t, "admin@example.com", response.Email)
}
func TestSetup_AlreadyCompleted(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: false}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusPreconditionFailed, rec.Code)
}
func TestSetup_MissingEmail(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"password": "securepassword123"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_InvalidEmail(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"email": "not-an-email", "password": "securepassword123", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
// Note: Invalid email format uses mail.ParseAddress which is treated differently
// and returns 400 Bad Request instead of 422 Unprocessable Entity
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_MissingPassword(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_PasswordTooShort(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "short", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_InvalidJSON(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouter(manager)
body := `{invalid json}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusBadRequest, rec.Code)
}
func TestSetup_CreateUserError(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, errors.New("user creation failed")
},
}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_ManagerError(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, status.Errorf(status.Internal, "database error")
},
}
router := setupTestRouter(manager)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "User"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}

View File

@@ -66,7 +66,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
}
srvUser := types.NewRegularUser(serviceUser)
srvUser := types.NewRegularUser(serviceUser, "", "")
srvUser.IsServiceUser = true
account := &types.Account{
@@ -75,7 +75,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
Peers: peersMap,
Users: map[string]*types.User{
adminUser: types.NewAdminUser(adminUser),
regularUser: types.NewRegularUser(regularUser),
regularUser: types.NewRegularUser(regularUser, "", ""),
serviceUser: srvUser,
},
Groups: map[string]*types.Group{

View File

@@ -326,6 +326,16 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
isCurrent := user.ID == currenUserID
var password *string
if user.Password != "" {
password = &user.Password
}
var idpID *string
if user.IdPID != "" {
idpID = &user.IdPID
}
return &api.User{
Id: user.ID,
Name: user.Name,
@@ -339,6 +349,8 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
LastLogin: &user.LastLogin,
Issued: &user.Issued,
PendingApproval: user.PendingApproval,
Password: password,
IdpId: idpID,
}
}

View File

@@ -134,6 +134,9 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
userAuth.IsChild = ok
}
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
// Available as userAuth.Email
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {

View File

@@ -94,7 +94,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -0,0 +1,234 @@
package server
import (
"context"
"errors"
"github.com/dexidp/dex/storage"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
// GetIdentityProviders returns all identity providers for an account
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
log.Warn("identity provider management requires embedded IdP")
return []*types.IdentityProvider{}, nil
}
connectors, err := embeddedManager.ListConnectors(ctx)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to list identity providers: %v", err)
}
result := make([]*types.IdentityProvider, 0, len(connectors))
for _, conn := range connectors {
result = append(result, connectorConfigToIdentityProvider(conn, accountID))
}
return result, nil
}
// GetIdentityProvider returns a specific identity provider by ID
func (am *DefaultAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP")
}
conn, err := embeddedManager.GetConnector(ctx, idpID)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, status.Errorf(status.NotFound, "identity provider not found")
}
return nil, status.Errorf(status.Internal, "failed to get identity provider: %v", err)
}
return connectorConfigToIdentityProvider(conn, accountID), nil
}
// CreateIdentityProvider creates a new identity provider
func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP")
}
// Generate ID if not provided
if idpConfig.ID == "" {
idpConfig.ID = generateIdentityProviderID(idpConfig.Type)
}
idpConfig.AccountID = accountID
connCfg := identityProviderToConnectorConfig(idpConfig)
_, err = embeddedManager.CreateConnector(ctx, connCfg)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to create identity provider: %v", err)
}
am.StoreEvent(ctx, userID, idpConfig.ID, accountID, activity.IdentityProviderCreated, idpConfig.EventMeta())
return idpConfig, nil
}
// UpdateIdentityProvider updates an existing identity provider
func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idpConfig *types.IdentityProvider) (*types.IdentityProvider, error) {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err := idpConfig.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return nil, status.Errorf(status.Internal, "identity provider management requires embedded IdP")
}
idpConfig.ID = idpID
idpConfig.AccountID = accountID
connCfg := identityProviderToConnectorConfig(idpConfig)
if err := embeddedManager.UpdateConnector(ctx, connCfg); err != nil {
return nil, status.Errorf(status.Internal, "failed to update identity provider: %v", err)
}
am.StoreEvent(ctx, userID, idpConfig.ID, accountID, activity.IdentityProviderUpdated, idpConfig.EventMeta())
return idpConfig, nil
}
// DeleteIdentityProvider deletes an identity provider
func (am *DefaultAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error {
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return status.Errorf(status.Internal, "identity provider management requires embedded IdP")
}
// Get the IDP info before deleting for the activity event
conn, err := embeddedManager.GetConnector(ctx, idpID)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return status.Errorf(status.NotFound, "identity provider not found")
}
return status.Errorf(status.Internal, "failed to get identity provider: %v", err)
}
idpConfig := connectorConfigToIdentityProvider(conn, accountID)
if err := embeddedManager.DeleteConnector(ctx, idpID); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return status.Errorf(status.NotFound, "identity provider not found")
}
return status.Errorf(status.Internal, "failed to delete identity provider: %v", err)
}
am.StoreEvent(ctx, userID, idpID, accountID, activity.IdentityProviderDeleted, idpConfig.EventMeta())
return nil
}
// connectorConfigToIdentityProvider converts a dex.ConnectorConfig to types.IdentityProvider
func connectorConfigToIdentityProvider(conn *dex.ConnectorConfig, accountID string) *types.IdentityProvider {
return &types.IdentityProvider{
ID: conn.ID,
AccountID: accountID,
Type: types.IdentityProviderType(conn.Type),
Name: conn.Name,
Issuer: conn.Issuer,
ClientID: conn.ClientID,
ClientSecret: conn.ClientSecret,
}
}
// identityProviderToConnectorConfig converts a types.IdentityProvider to dex.ConnectorConfig
func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.ConnectorConfig {
return &dex.ConnectorConfig{
ID: idpConfig.ID,
Name: idpConfig.Name,
Type: string(idpConfig.Type),
Issuer: idpConfig.Issuer,
ClientID: idpConfig.ClientID,
ClientSecret: idpConfig.ClientSecret,
}
}
// generateIdentityProviderID generates a unique ID for an identity provider.
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft),
// the ID is prefixed with the type name. Generic OIDC providers get no prefix.
func generateIdentityProviderID(idpType types.IdentityProviderType) string {
id := xid.New().String()
switch idpType {
case types.IdentityProviderTypeOkta:
return "okta-" + id
case types.IdentityProviderTypeZitadel:
return "zitadel-" + id
case types.IdentityProviderTypeEntra:
return "entra-" + id
case types.IdentityProviderTypeGoogle:
return "google-" + id
case types.IdentityProviderTypePocketID:
return "pocketid-" + id
case types.IdentityProviderTypeMicrosoft:
return "microsoft-" + id
case types.IdentityProviderTypeAuthentik:
return "authentik-" + id
case types.IdentityProviderTypeKeycloak:
return "keycloak-" + id
default:
// Generic OIDC - no prefix
return id
}
}

View File

@@ -0,0 +1,202 @@
package server
import (
"context"
"path/filepath"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
)
func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) {
t.Helper()
ctx := context.Background()
dataDir := t.TempDir()
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "", dataDir)
if err != nil {
return nil, nil, err
}
t.Cleanup(cleanUp)
// Create embedded IdP manager
embeddedConfig := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: filepath.Join(dataDir, "dex.db"),
},
},
}
idpManager, err := idp.NewEmbeddedIdPManager(ctx, embeddedConfig, nil)
if err != nil {
return nil, nil, err
}
t.Cleanup(func() { _ = idpManager.Stop(ctx) })
eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
if err != nil {
return nil, nil, err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(false, nil).
AnyTimes()
permissionsManager := permissions.NewManager(testStore)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, testStore)
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peers.NewManager(testStore, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, nil, err
}
return manager, updateManager, nil
}
func TestDefaultAccountManager_CreateIdentityProvider_Validation(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err)
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
testCases := []struct {
name string
idp *types.IdentityProvider
expectError bool
errorMsg string
}{
{
name: "Missing Name",
idp: &types.IdentityProvider{
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
ClientID: "client-id",
},
expectError: true,
errorMsg: "name is required",
},
{
name: "Missing Type",
idp: &types.IdentityProvider{
Name: "Test IDP",
Issuer: "https://issuer.example.com",
ClientID: "client-id",
},
expectError: true,
errorMsg: "type is required",
},
{
name: "Missing Issuer",
idp: &types.IdentityProvider{
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
ClientID: "client-id",
},
expectError: true,
errorMsg: "issuer is required",
},
{
name: "Missing ClientID",
idp: &types.IdentityProvider{
Name: "Test IDP",
Type: types.IdentityProviderTypeOIDC,
Issuer: "https://issuer.example.com",
},
expectError: true,
errorMsg: "client ID is required",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := manager.CreateIdentityProvider(context.Background(), account.Id, userID, tc.idp)
if tc.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.errorMsg)
}
})
}
}
func TestDefaultAccountManager_GetIdentityProviders(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err)
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
// Should return empty list (stub implementation)
providers, err := manager.GetIdentityProviders(context.Background(), account.Id, userID)
require.NoError(t, err)
assert.Empty(t, providers)
}
func TestDefaultAccountManager_GetIdentityProvider_NotFound(t *testing.T) {
manager, _, err := createManagerWithEmbeddedIdP(t)
require.NoError(t, err)
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
// Should return not found error when identity provider doesn't exist
_, err = manager.GetIdentityProvider(context.Background(), account.Id, "any-id", userID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not found")
}
func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err)
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err)
// Should fail validation before reaching "not implemented" error
invalidIDP := &types.IdentityProvider{
Name: "", // Empty name should fail validation
}
_, err = manager.UpdateIdentityProvider(context.Background(), account.Id, "some-id", userID, invalidIDP)
require.Error(t, err)
assert.Contains(t, err.Error(), "name is required")
}

View File

@@ -0,0 +1,511 @@
package idp
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/dexidp/dex/storage"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
staticClientDashboard = "netbird-dashboard"
staticClientCLI = "netbird-cli"
defaultCLIRedirectURL1 = "http://localhost:53000/"
defaultCLIRedirectURL2 = "http://localhost:54000/"
defaultScopes = "openid profile email offline_access"
defaultUserIDClaim = "sub"
)
// EmbeddedIdPConfig contains configuration for the embedded Dex OIDC identity provider
type EmbeddedIdPConfig struct {
// Enabled indicates whether the embedded IDP is enabled
Enabled bool
// Issuer is the OIDC issuer URL (e.g., "http://localhost:3002/oauth2")
Issuer string
// Storage configuration for the IdP database
Storage EmbeddedStorageConfig
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
DashboardRedirectURIs []string
// DashboardRedirectURIs are the OAuth2 redirect URIs for the dashboard client
CLIRedirectURIs []string
// Owner is the initial owner/admin user (optional, can be nil)
Owner *OwnerConfig
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
SignKeyRefreshEnabled bool
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
type EmbeddedStorageConfig struct {
// Type is the storage type (currently only "sqlite3" is supported)
Type string
// Config contains type-specific configuration
Config EmbeddedStorageTypeConfig
}
// EmbeddedStorageTypeConfig contains type-specific storage configuration.
type EmbeddedStorageTypeConfig struct {
// File is the path to the SQLite database file (for sqlite3 type)
File string
}
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
type OwnerConfig struct {
// Email is the user's email address (required)
Email string
// Hash is the bcrypt hash of the user's password (required)
Hash string
// Username is the display name for the user (optional, defaults to email)
Username string
}
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Issuer == "" {
return nil, fmt.Errorf("issuer is required")
}
if c.Storage.Type == "" {
c.Storage.Type = "sqlite3"
}
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
return nil, fmt.Errorf("storage file is required for sqlite3")
}
// Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
cfg := &dex.YAMLConfig{
Issuer: c.Issuer,
Storage: dex.Storage{
Type: c.Storage.Type,
Config: map[string]interface{}{
"file": c.Storage.Config.File,
},
},
Web: dex.Web{
AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"Authorization", "Content-Type"},
},
OAuth2: dex.OAuth2{
SkipApprovalScreen: true,
},
Frontend: dex.Frontend{
Issuer: "NetBird",
Theme: "light",
},
EnablePasswordDB: true,
StaticClients: []storage.Client{
{
ID: staticClientDashboard,
Name: "NetBird Dashboard",
Public: true,
RedirectURIs: c.DashboardRedirectURIs,
},
{
ID: staticClientCLI,
Name: "NetBird CLI",
Public: true,
RedirectURIs: cliRedirectURIs,
},
},
}
// Add owner user if provided
if c.Owner != nil && c.Owner.Email != "" && c.Owner.Hash != "" {
username := c.Owner.Username
if username == "" {
username = c.Owner.Email
}
cfg.StaticPasswords = []dex.Password{
{
Email: c.Owner.Email,
Hash: []byte(c.Owner.Hash),
Username: username,
UserID: uuid.New().String(),
},
}
}
return cfg, nil
}
// Compile-time check that EmbeddedIdPManager implements Manager interface
var _ Manager = (*EmbeddedIdPManager)(nil)
// Compile-time check that EmbeddedIdPManager implements OAuthConfigProvider interface
var _ OAuthConfigProvider = (*EmbeddedIdPManager)(nil)
// OAuthConfigProvider defines the interface for OAuth configuration needed by auth flows.
type OAuthConfigProvider interface {
GetIssuer() string
GetKeysLocation() string
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
GetDeviceAuthEndpoint() string
GetAuthorizationEndpoint() string
GetDefaultScopes() string
GetCLIClientID() string
GetCLIRedirectURLs() []string
}
// EmbeddedIdPManager implements the Manager interface using the embedded Dex IdP.
type EmbeddedIdPManager struct {
provider *dex.Provider
appMetrics telemetry.AppMetrics
config EmbeddedIdPConfig
}
// NewEmbeddedIdPManager creates a new instance of EmbeddedIdPManager from a configuration.
// It instantiates the underlying Dex provider internally.
// Note: Storage defaults are applied in config loading (applyEmbeddedIdPConfig) based on Datadir.
func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMetrics telemetry.AppMetrics) (*EmbeddedIdPManager, error) {
if config == nil {
return nil, fmt.Errorf("embedded IdP config is required")
}
// Apply defaults for CLI redirect URIs
if len(config.CLIRedirectURIs) == 0 {
config.CLIRedirectURIs = []string{defaultCLIRedirectURL1, defaultCLIRedirectURL2}
}
// there are some properties create when creating YAML config (e.g., auth clients)
yamlConfig, err := config.ToYAMLConfig()
if err != nil {
return nil, err
}
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
if err != nil {
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
}
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
return &EmbeddedIdPManager{
provider: provider,
appMetrics: appMetrics,
config: *config,
}, nil
}
// Handler returns the HTTP handler for serving OIDC requests.
func (m *EmbeddedIdPManager) Handler() http.Handler {
return m.provider.Handler()
}
// Stop gracefully shuts down the embedded IdP provider.
func (m *EmbeddedIdPManager) Stop(ctx context.Context) error {
return m.provider.Stop(ctx)
}
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
func (m *EmbeddedIdPManager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
// TODO: implement
return nil
}
// GetUserDataByID requests user data from the embedded IdP via user ID.
func (m *EmbeddedIdPManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
user, err := m.provider.GetUserByID(ctx, userID)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to get user by ID: %w", err)
}
return &UserData{
Email: user.Email,
Name: user.Username,
ID: user.UserID,
AppMetadata: appMetadata,
}, nil
}
// GetAccount returns all the users for a given account.
// Note: Embedded dex doesn't store account metadata, so this returns all users.
func (m *EmbeddedIdPManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
users, err := m.provider.ListUsers(ctx)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to list users: %w", err)
}
result := make([]*UserData, 0, len(users))
for _, user := range users {
result = append(result, &UserData{
Email: user.Email,
Name: user.Username,
ID: user.UserID,
AppMetadata: AppMetadata{
WTAccountID: accountID,
},
})
}
return result, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// Note: Embedded dex doesn't store account metadata, so all users are indexed under UnsetAccountID.
func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountGetAllAccounts()
}
users, err := m.provider.ListUsers(ctx)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to list users: %w", err)
}
indexedUsers := make(map[string][]*UserData)
for _, user := range users {
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
Email: user.Email,
Name: user.Username,
ID: user.UserID,
})
}
return indexedUsers, nil
}
// CreateUser creates a new user in the embedded IdP.
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountCreateUser()
}
// Check if user already exists
_, err := m.provider.GetUser(ctx, email)
if err == nil {
return nil, fmt.Errorf("user with email %s already exists", email)
}
if !errors.Is(err, storage.ErrNotFound) {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to check existing user: %w", err)
}
// Generate a random password for the new user
password := GeneratePassword(16, 2, 2, 2)
// Create the user via provider (handles hashing and ID generation)
// The provider returns an encoded user ID in Dex's format (base64 protobuf with connector ID)
userID, err := m.provider.CreateUser(ctx, email, name, password)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
}
log.WithContext(ctx).Debugf("created user %s in embedded IdP", email)
return &UserData{
Email: email,
Name: name,
ID: userID,
Password: password,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTInvitedBy: invitedByEmail,
},
}, nil
}
// GetUserByEmail searches users with a given email.
func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
user, err := m.provider.GetUser(ctx, email)
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
return nil, nil // Return empty slice for not found
}
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to get user by email: %w", err)
}
return []*UserData{
{
Email: user.Email,
Name: user.Username,
ID: user.UserID,
},
}, nil
}
// CreateUserWithPassword creates a new user in the embedded IdP with a provided password.
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
// This is useful for instance setup where the user provides their own password.
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountCreateUser()
}
// Check if user already exists
_, err := m.provider.GetUser(ctx, email)
if err == nil {
return nil, fmt.Errorf("user with email %s already exists", email)
}
if !errors.Is(err, storage.ErrNotFound) {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to check existing user: %w", err)
}
// Create the user via provider with the provided password
userID, err := m.provider.CreateUser(ctx, email, name, password)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
}
log.WithContext(ctx).Debugf("created user %s in embedded IdP with provided password", email)
return &UserData{
Email: email,
Name: name,
ID: userID,
}, nil
}
// InviteUserByID resends an invitation to a user.
func (m *EmbeddedIdPManager) InviteUserByID(ctx context.Context, userID string) error {
// TODO: implement
return fmt.Errorf("not implemented")
}
// DeleteUser deletes a user from the embedded IdP by user ID.
func (m *EmbeddedIdPManager) DeleteUser(ctx context.Context, userID string) error {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountDeleteUser()
}
// Get user by ID to retrieve email (provider.DeleteUser requires email)
user, err := m.provider.GetUserByID(ctx, userID)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return fmt.Errorf("failed to get user for deletion: %w", err)
}
err = m.provider.DeleteUser(ctx, user.Email)
if err != nil {
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountRequestError()
}
return fmt.Errorf("failed to delete user from embedded IdP: %w", err)
}
log.WithContext(ctx).Debugf("deleted user %s from embedded IdP", user.Email)
return nil
}
// CreateConnector creates a new identity provider connector in Dex.
// Returns the created connector config with the redirect URL populated.
func (m *EmbeddedIdPManager) CreateConnector(ctx context.Context, cfg *dex.ConnectorConfig) (*dex.ConnectorConfig, error) {
return m.provider.CreateConnector(ctx, cfg)
}
// GetConnector retrieves an identity provider connector by ID.
func (m *EmbeddedIdPManager) GetConnector(ctx context.Context, id string) (*dex.ConnectorConfig, error) {
return m.provider.GetConnector(ctx, id)
}
// ListConnectors returns all identity provider connectors.
func (m *EmbeddedIdPManager) ListConnectors(ctx context.Context) ([]*dex.ConnectorConfig, error) {
return m.provider.ListConnectors(ctx)
}
// UpdateConnector updates an existing identity provider connector.
func (m *EmbeddedIdPManager) UpdateConnector(ctx context.Context, cfg *dex.ConnectorConfig) error {
// Preserve existing secret if not provided in update
if cfg.ClientSecret == "" {
existing, err := m.provider.GetConnector(ctx, cfg.ID)
if err != nil {
return fmt.Errorf("failed to get existing connector: %w", err)
}
cfg.ClientSecret = existing.ClientSecret
}
return m.provider.UpdateConnector(ctx, cfg)
}
// DeleteConnector removes an identity provider connector.
func (m *EmbeddedIdPManager) DeleteConnector(ctx context.Context, id string) error {
return m.provider.DeleteConnector(ctx, id)
}
// GetIssuer returns the OIDC issuer URL.
func (m *EmbeddedIdPManager) GetIssuer() string {
return m.provider.GetIssuer()
}
// GetTokenEndpoint returns the OAuth2 token endpoint URL.
func (m *EmbeddedIdPManager) GetTokenEndpoint() string {
return m.provider.GetTokenEndpoint()
}
// GetDeviceAuthEndpoint returns the OAuth2 device authorization endpoint URL.
func (m *EmbeddedIdPManager) GetDeviceAuthEndpoint() string {
return m.provider.GetDeviceAuthEndpoint()
}
// GetAuthorizationEndpoint returns the OAuth2 authorization endpoint URL.
func (m *EmbeddedIdPManager) GetAuthorizationEndpoint() string {
return m.provider.GetAuthorizationEndpoint()
}
// GetDefaultScopes returns the default OAuth2 scopes for authentication.
func (m *EmbeddedIdPManager) GetDefaultScopes() string {
return defaultScopes
}
// GetCLIClientID returns the client ID for CLI authentication.
func (m *EmbeddedIdPManager) GetCLIClientID() string {
return staticClientCLI
}
// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client.
func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
if len(m.config.CLIRedirectURIs) == 0 {
return []string{defaultCLIRedirectURL1, defaultCLIRedirectURL2}
}
return m.config.CLIRedirectURIs
}
// GetKeysLocation returns the JWKS endpoint URL for token validation.
func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()
}
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
func (m *EmbeddedIdPManager) GetClientIDs() []string {
return []string{staticClientDashboard, staticClientCLI}
}
// GetUserIDClaim returns the JWT claim name used for user identification.
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
return defaultUserIDClaim
}

View File

@@ -0,0 +1,249 @@
package idp
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/idp/dex"
)
func TestEmbeddedIdPManager_CreateUser_EndToEnd(t *testing.T) {
ctx := context.Background()
// Create a temporary directory for the test
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
// Create the embedded IDP config
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
// Create the embedded IDP manager
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Test data
email := "newuser@example.com"
name := "New User"
accountID := "test-account-id"
invitedByEmail := "admin@example.com"
// Create the user
userData, err := manager.CreateUser(ctx, email, name, accountID, invitedByEmail)
require.NoError(t, err)
require.NotNil(t, userData)
t.Logf("Created user: ID=%s, Email=%s, Name=%s, Password=%s",
userData.ID, userData.Email, userData.Name, userData.Password)
// Verify user data
assert.Equal(t, email, userData.Email)
assert.Equal(t, name, userData.Name)
assert.NotEmpty(t, userData.ID)
assert.NotEmpty(t, userData.Password)
assert.Equal(t, accountID, userData.AppMetadata.WTAccountID)
assert.Equal(t, invitedByEmail, userData.AppMetadata.WTInvitedBy)
// Verify the user ID is in Dex's encoded format (base64 protobuf)
rawUserID, connectorID, err := dex.DecodeDexUserID(userData.ID)
require.NoError(t, err)
assert.NotEmpty(t, rawUserID)
assert.Equal(t, "local", connectorID)
t.Logf("Decoded user ID: rawUserID=%s, connectorID=%s", rawUserID, connectorID)
// Verify we can look up the user by the encoded ID
lookedUpUser, err := manager.GetUserDataByID(ctx, userData.ID, AppMetadata{WTAccountID: accountID})
require.NoError(t, err)
assert.Equal(t, email, lookedUpUser.Email)
// Verify we can look up by email
users, err := manager.GetUserByEmail(ctx, email)
require.NoError(t, err)
require.Len(t, users, 1)
assert.Equal(t, email, users[0].Email)
// Verify creating duplicate user fails
_, err = manager.CreateUser(ctx, email, name, accountID, invitedByEmail)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already exists")
}
func TestEmbeddedIdPManager_GetUserDataByID_WithEncodedID(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Create a user first
userData, err := manager.CreateUser(ctx, "test@example.com", "Test User", "account1", "admin@example.com")
require.NoError(t, err)
// The returned ID should be encoded
encodedID := userData.ID
// Lookup should work with the encoded ID
lookedUp, err := manager.GetUserDataByID(ctx, encodedID, AppMetadata{WTAccountID: "account1"})
require.NoError(t, err)
assert.Equal(t, "test@example.com", lookedUp.Email)
assert.Equal(t, "Test User", lookedUp.Name)
}
func TestEmbeddedIdPManager_DeleteUser(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Create a user
userData, err := manager.CreateUser(ctx, "delete-me@example.com", "Delete Me", "account1", "admin@example.com")
require.NoError(t, err)
// Delete the user using the encoded ID
err = manager.DeleteUser(ctx, userData.ID)
require.NoError(t, err)
// Verify user no longer exists
_, err = manager.GetUserDataByID(ctx, userData.ID, AppMetadata{})
assert.Error(t, err)
}
func TestEmbeddedIdPManager_GetAccount(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Create multiple users
_, err = manager.CreateUser(ctx, "user1@example.com", "User 1", "account1", "admin@example.com")
require.NoError(t, err)
_, err = manager.CreateUser(ctx, "user2@example.com", "User 2", "account1", "admin@example.com")
require.NoError(t, err)
// Get all users for the account
users, err := manager.GetAccount(ctx, "account1")
require.NoError(t, err)
assert.Len(t, users, 2)
emails := make([]string, len(users))
for i, u := range users {
emails[i] = u.Email
}
assert.Contains(t, emails, "user1@example.com")
assert.Contains(t, emails, "user2@example.com")
}
func TestEmbeddedIdPManager_UserIDFormat_MatchesJWT(t *testing.T) {
// This test verifies that the user ID returned by CreateUser
// matches the format that Dex uses in JWT tokens (the 'sub' claim)
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Create a user
userData, err := manager.CreateUser(ctx, "jwt-test@example.com", "JWT Test", "account1", "admin@example.com")
require.NoError(t, err)
// The ID should be in the format: base64(protobuf{user_id, connector_id})
// Example: CiQ3YWFkOGMwNS0zMjg3LTQ3M2YtYjQyYS0zNjU1MDRiZjI1ZTcSBWxvY2Fs
// Verify it can be decoded
rawUserID, connectorID, err := dex.DecodeDexUserID(userData.ID)
require.NoError(t, err)
// Raw user ID should be a UUID
assert.Regexp(t, `^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`, rawUserID)
// Connector ID should be "local" for password-based auth
assert.Equal(t, "local", connectorID)
// Re-encoding should produce the same result
reEncoded := dex.EncodeDexUserID(rawUserID, connectorID)
assert.Equal(t, userData.ID, reEncoded)
t.Logf("User ID format verified:")
t.Logf(" Encoded ID: %s", userData.ID)
t.Logf(" Raw UUID: %s", rawUserID)
t.Logf(" Connector: %s", connectorID)
}

View File

@@ -72,6 +72,7 @@ type UserData struct {
Name string `json:"name"`
ID string `json:"user_id"`
AppMetadata AppMetadata `json:"app_metadata"`
Password string `json:"-"` // Plain password, only set on user creation, excluded from JSON
}
func (u *UserData) MarshalBinary() (data []byte, err error) {

View File

@@ -0,0 +1,136 @@
package instance
import (
"context"
"errors"
"fmt"
"net/mail"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
// Manager handles instance-level operations like initial setup.
type Manager interface {
// IsSetupRequired checks if instance setup is required.
// Returns true if embedded IDP is enabled and no accounts exist.
IsSetupRequired(ctx context.Context) (bool, error)
// CreateOwnerUser creates the initial owner user in the embedded IDP.
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
// DefaultManager is the default implementation of Manager.
type DefaultManager struct {
store store.Store
embeddedIdpManager *idp.EmbeddedIdPManager
setupRequired bool
setupMu sync.RWMutex
}
// NewManager creates a new instance manager.
// If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults.
func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) {
embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager)
m := &DefaultManager{
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
}
if embeddedIdp != nil {
err := m.loadSetupRequired(ctx)
if err != nil {
return nil, err
}
}
return m, nil
}
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
if err != nil {
return err
}
m.setupMu.Lock()
m.setupRequired = len(users) == 0
m.setupMu.Unlock()
return nil
}
// IsSetupRequired checks if instance setup is required.
// Setup is required when:
// 1. Embedded IDP is enabled
// 2. No accounts exist in the store
func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) {
if m.embeddedIdpManager == nil {
return false, nil
}
m.setupMu.RLock()
defer m.setupMu.RUnlock()
return m.setupRequired, nil
}
// CreateOwnerUser creates the initial owner user in the embedded IDP.
func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
m.setupMu.RLock()
setupRequired := m.setupRequired
m.setupMu.RUnlock()
if !setupRequired {
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
}
userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
if err != nil {
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
}
m.setupMu.Lock()
m.setupRequired = false
m.setupMu.Unlock()
log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email)
return userData, nil
}
func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if email == "" {
return status.Errorf(status.InvalidArgument, "email is required")
}
if _, err := mail.ParseAddress(email); err != nil {
return status.Errorf(status.InvalidArgument, "invalid email format")
}
if name == "" {
return status.Errorf(status.InvalidArgument, "name is required")
}
if password == "" {
return status.Errorf(status.InvalidArgument, "password is required")
}
if len(password) < 8 {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
}
return nil
}

View File

@@ -0,0 +1,268 @@
package instance
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
)
// mockStore implements a minimal store.Store for testing
type mockStore struct {
accountsCount int64
err error
}
func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) {
if m.err != nil {
return 0, m.err
}
return m.accountsCount, nil
}
// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing
type mockEmbeddedIdPManager struct {
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
}
return &idp.UserData{
ID: "test-user-id",
Email: email,
Name: name,
}, nil
}
// testManager is a test implementation that accepts our mock types
type testManager struct {
store *mockStore
embeddedIdpManager *mockEmbeddedIdPManager
}
func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) {
if m.embeddedIdpManager == nil {
return false, nil
}
count, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return false, err
}
return count == 0, nil
}
func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil, // No embedded IDP
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when embedded IDP is disabled")
}
func TestIsSetupRequired_NoAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts exist")
}
func TestIsSetupRequired_AccountsExist(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 1},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when accounts exist")
}
func TestIsSetupRequired_MultipleAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 5},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when multiple accounts exist")
}
func TestIsSetupRequired_StoreError(t *testing.T) {
manager := &testManager{
store: &mockStore{err: errors.New("database error")},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
_, err := manager.IsSetupRequired(context.Background())
assert.Error(t, err, "should return error when store fails")
}
func TestCreateOwnerUser_Success(t *testing.T) {
expectedEmail := "admin@example.com"
expectedName := "Admin User"
expectedPassword := "securepassword123"
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
assert.Equal(t, expectedEmail, email)
assert.Equal(t, expectedPassword, password)
assert.Equal(t, expectedName, name)
return &idp.UserData{
ID: "created-user-id",
Email: email,
Name: name,
}, nil
},
},
}
userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName)
require.NoError(t, err)
assert.Equal(t, "created-user-id", userData.ID)
assert.Equal(t, expectedEmail, userData.Email)
assert.Equal(t, expectedName, userData.Name)
}
func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil,
}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when embedded IDP is disabled")
assert.Contains(t, err.Error(), "embedded IDP is not enabled")
}
func TestCreateOwnerUser_IdPError(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, errors.New("user already exists")
},
},
}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when IDP fails")
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{
setupRequired: true,
}
tests := []struct {
name string
email string
password string
userName string
expectError bool
errorMsg string
}{
{
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
expectError: false,
},
{
name: "empty email",
email: "",
password: "password123",
userName: "Admin User",
expectError: true,
errorMsg: "email is required",
},
{
name: "invalid email format",
email: "not-an-email",
password: "password123",
userName: "Admin User",
expectError: true,
errorMsg: "invalid email format",
},
{
name: "empty name",
email: "admin@example.com",
password: "password123",
userName: "",
expectError: true,
errorMsg: "name is required",
},
{
name: "empty password",
email: "admin@example.com",
password: "",
userName: "Admin User",
expectError: true,
errorMsg: "password is required",
},
{
name: "password too short",
email: "admin@example.com",
password: "short",
userName: "Admin User",
expectError: true,
errorMsg: "password must be at least 8 characters",
},
{
name: "password exactly 8 characters",
email: "admin@example.com",
password: "12345678",
userName: "Admin User",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := manager.validateSetupInfo(tt.email, tt.password, tt.userName)
if tt.expectError {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errorMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
manager := &DefaultManager{
setupRequired: false,
embeddedIdpManager: &idp.EmbeddedIdPManager{},
}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}

View File

@@ -381,7 +381,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
return nil, nil, "", cleanup, err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, nil, "", cleanup, err
}

View File

@@ -242,6 +242,7 @@ func startServer(
nil,
server.MockIntegratedValidator{},
networkMapController,
nil,
)
if err != nil {
t.Fatalf("failed creating management server: %v", err)

View File

@@ -27,13 +27,13 @@ import (
var _ account.Manager = (*MockAccountManager)(nil)
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error)
GetOrCreateAccountByUserFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error)
GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error)
AccountExistsFunc func(ctx context.Context, accountID string) (bool, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error)
GetAccountIDByUserIdFunc func(ctx context.Context, userAuth auth.UserAuth) (string, error)
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
@@ -129,6 +129,12 @@ type MockAccountManager struct {
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
GetIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error)
GetIdentityProvidersFunc func(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error)
CreateIdentityProviderFunc func(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
UpdateIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error)
DeleteIdentityProviderFunc func(ctx context.Context, accountID, idpID, userID string) error
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -237,10 +243,10 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID,
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
func (am *MockAccountManager) GetOrCreateAccountByUser(
ctx context.Context, userId, domain string,
ctx context.Context, userAuth auth.UserAuth,
) (*types.Account, error) {
if am.GetOrCreateAccountByUserFunc != nil {
return am.GetOrCreateAccountByUserFunc(ctx, userId, domain)
return am.GetOrCreateAccountByUserFunc(ctx, userAuth)
}
return nil, status.Errorf(
codes.Unimplemented,
@@ -276,9 +282,9 @@ func (am *MockAccountManager) AccountExists(ctx context.Context, accountID strin
}
// GetAccountIDByUserID mock implementation of GetAccountIDByUserID from server.AccountManager interface
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, domain string) (string, error) {
func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth auth.UserAuth) (string, error) {
if am.GetAccountIDByUserIdFunc != nil {
return am.GetAccountIDByUserIdFunc(ctx, userId, domain)
return am.GetAccountIDByUserIdFunc(ctx, userAuth)
}
return "", status.Errorf(
codes.Unimplemented,
@@ -993,3 +999,43 @@ func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, ac
func (am *MockAccountManager) GetUserIDByPeerKey(ctx context.Context, peerKey string) (string, error) {
return "something", nil
}
// GetIdentityProvider mocks GetIdentityProvider of the AccountManager interface
func (am *MockAccountManager) GetIdentityProvider(ctx context.Context, accountID, idpID, userID string) (*types.IdentityProvider, error) {
if am.GetIdentityProviderFunc != nil {
return am.GetIdentityProviderFunc(ctx, accountID, idpID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetIdentityProvider is not implemented")
}
// GetIdentityProviders mocks GetIdentityProviders of the AccountManager interface
func (am *MockAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
if am.GetIdentityProvidersFunc != nil {
return am.GetIdentityProvidersFunc(ctx, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetIdentityProviders is not implemented")
}
// CreateIdentityProvider mocks CreateIdentityProvider of the AccountManager interface
func (am *MockAccountManager) CreateIdentityProvider(ctx context.Context, accountID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
if am.CreateIdentityProviderFunc != nil {
return am.CreateIdentityProviderFunc(ctx, accountID, userID, idp)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateIdentityProvider is not implemented")
}
// UpdateIdentityProvider mocks UpdateIdentityProvider of the AccountManager interface
func (am *MockAccountManager) UpdateIdentityProvider(ctx context.Context, accountID, idpID, userID string, idp *types.IdentityProvider) (*types.IdentityProvider, error) {
if am.UpdateIdentityProviderFunc != nil {
return am.UpdateIdentityProviderFunc(ctx, accountID, idpID, userID, idp)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdateIdentityProvider is not implemented")
}
// DeleteIdentityProvider mocks DeleteIdentityProvider of the AccountManager interface
func (am *MockAccountManager) DeleteIdentityProvider(ctx context.Context, accountID, idpID, userID string) error {
if am.DeleteIdentityProviderFunc != nil {
return am.DeleteIdentityProviderFunc(ctx, accountID, idpID, userID)
}
return status.Errorf(codes.Unimplemented, "method DeleteIdentityProvider is not implemented")
}

View File

@@ -865,7 +865,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
userID := testUserID
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup

View File

@@ -502,7 +502,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
account.Users[someUser] = &types.User{
Id: someUser,
Role: types.UserRoleUser,
@@ -689,7 +689,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
accountID := "test_account"
adminUser := "account_creator"
someUser := "some_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
account.Users[someUser] = &types.User{
Id: someUser,
Role: testCase.role,
@@ -759,7 +759,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou
adminUser := "account_creator"
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
account.Users[regularUser] = &types.User{
Id: regularUser,
Role: types.UserRoleUser,
@@ -2124,7 +2124,7 @@ func Test_DeletePeer(t *testing.T) {
// account with an admin and a regular user
accountID := "test_account"
adminUser := "account_creator"
account := newAccountWithId(context.Background(), accountID, adminUser, "", false)
account := newAccountWithId(context.Background(), accountID, adminUser, "", "", "", false)
account.Peers = map[string]*nbpeer.Peer{
"peer1": {
ID: "peer1",
@@ -2307,12 +2307,12 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) {
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser := types.NewRegularUser("pending-user", "", "")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
@@ -2344,12 +2344,12 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) {
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create regular user (not pending approval)
regularUser := types.NewRegularUser("regular-user")
regularUser := types.NewRegularUser("regular-user", "", "")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
@@ -2378,12 +2378,12 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) {
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser := types.NewRegularUser("pending-user", "", "")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
@@ -2443,12 +2443,12 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) {
}
// Create account
account := newAccountWithId(context.Background(), "test-account", "owner", "", false)
account := newAccountWithId(context.Background(), "test-account", "owner", "", "", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
// Create regular user (not pending approval)
regularUser := types.NewRegularUser("regular-user")
regularUser := types.NewRegularUser("regular-user", "", "")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)

View File

@@ -3,33 +3,35 @@ package modules
type Module string
const (
Networks Module = "networks"
Peers Module = "peers"
Groups Module = "groups"
Settings Module = "settings"
Accounts Module = "accounts"
Dns Module = "dns"
Nameservers Module = "nameservers"
Events Module = "events"
Policies Module = "policies"
Routes Module = "routes"
Users Module = "users"
SetupKeys Module = "setup_keys"
Pats Module = "pats"
Networks Module = "networks"
Peers Module = "peers"
Groups Module = "groups"
Settings Module = "settings"
Accounts Module = "accounts"
Dns Module = "dns"
Nameservers Module = "nameservers"
Events Module = "events"
Policies Module = "policies"
Routes Module = "routes"
Users Module = "users"
SetupKeys Module = "setup_keys"
Pats Module = "pats"
IdentityProviders Module = "identity_providers"
)
var All = map[Module]struct{}{
Networks: {},
Peers: {},
Groups: {},
Settings: {},
Accounts: {},
Dns: {},
Nameservers: {},
Events: {},
Policies: {},
Routes: {},
Users: {},
SetupKeys: {},
Pats: {},
Networks: {},
Peers: {},
Groups: {},
Settings: {},
Accounts: {},
Dns: {},
Nameservers: {},
Events: {},
Policies: {},
Routes: {},
Users: {},
SetupKeys: {},
Pats: {},
IdentityProviders: {},
}

View File

@@ -93,5 +93,11 @@ var NetworkAdmin = RolePermissions{
operations.Update: false,
operations.Delete: false,
},
modules.IdentityProviders: {
operations.Read: true,
operations.Create: false,
operations.Update: false,
operations.Delete: false,
},
},
}

View File

@@ -109,7 +109,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er
ID: "peer1",
}
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, false)
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
account.Users[admin.Id] = admin
account.Users[user.Id] = user
account.Peers["peer1"] = peer1

View File

@@ -1320,7 +1320,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou
accountID := "testingAcc"
domain := "example.com"
account := newAccountWithId(context.Background(), accountID, userID, domain, false)
account := newAccountWithId(context.Background(), accountID, userID, domain, "", "", false)
err := am.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, err

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
)
func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
@@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
if err != nil {
t.Fatal(err)
}
@@ -99,7 +100,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
if err != nil {
t.Fatal(err)
}
@@ -204,7 +205,7 @@ func TestGetSetupKeys(t *testing.T) {
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
if err != nil {
t.Fatal(err)
}
@@ -471,7 +472,7 @@ func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t
}
userID := "testingUser"
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: userID})
if err != nil {
t.Fatal(err)
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
nbutil "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
)
// storeFileName Store file name. Stored in the datadir
@@ -263,3 +264,8 @@ func (s *FileStore) Close(ctx context.Context) error {
func (s *FileStore) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}
// SetFieldEncrypt is a no-op for FileStore as it doesn't support field encryption.
func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
// no-op: FileStore stores data in plaintext JSON; encryption is not supported
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util/crypt"
)
const (
@@ -57,13 +58,13 @@ const (
// SqlStore represents an account storage backed by a Sql DB persisted to disk
type SqlStore struct {
db *gorm.DB
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine types.Engine
pool *pgxpool.Pool
db *gorm.DB
globalAccountLock sync.Mutex
metrics telemetry.AppMetrics
installationPK int
storeEngine types.Engine
pool *pgxpool.Pool
fieldEncrypt *crypt.FieldEncrypt
transactionTimeout time.Duration
}
@@ -175,6 +176,13 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
// Encrypt sensitive user data before saving
for i := range account.UsersG {
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
}
for _, group := range account.GroupsG {
group.StoreGroupPeers()
}
@@ -440,7 +448,18 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
return nil
}
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users)
usersCopy := make([]*types.User, len(users))
for i, user := range users {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
usersCopy[i] = userCopy
}
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&usersCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save users to store")
@@ -450,7 +469,15 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
result := s.db.Save(user)
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
result := s.db.Save(userCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user to store")
@@ -600,6 +627,10 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
return nil, status.NewGetUserFromStoreError()
}
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
return &user, nil
}
@@ -618,6 +649,10 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return nil, status.NewGetUserFromStoreError()
}
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
return &user, nil
}
@@ -654,6 +689,12 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
return nil, status.Errorf(status.Internal, "issue getting users from store")
}
for _, user := range users {
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
}
return users, nil
}
@@ -672,6 +713,10 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return nil, status.Errorf(status.Internal, "failed to get account owner from the store")
}
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
return &user, nil
}
@@ -866,6 +911,9 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
if user.AutoGroups == nil {
user.AutoGroups = []string{}
}
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
account.Users[user.Id] = &user
user.PATsG = nil
}
@@ -1141,6 +1189,9 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.PATs = make(map[string]*types.PersonalAccessToken)
if userPats, ok := patsByUserID[user.Id]; ok {
for j := range userPats {
@@ -1545,7 +1596,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
}
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1`
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -1555,7 +1606,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
var autoGroups []byte
var lastLogin, createdAt sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType)
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err == nil {
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
@@ -3012,6 +3063,11 @@ func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
func (s *SqlStore) SetFieldEncrypt(enc *crypt.FieldEncrypt) {
s.fieldEncrypt = enc
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -32,6 +32,7 @@ import (
nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/util/crypt"
)
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
@@ -2090,7 +2091,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
setupKeys := map[string]*types.SetupKey{}
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
owner := types.NewOwnerUser(userID)
owner := types.NewOwnerUser(userID, "", "")
owner.AccountID = accountID
users[userID] = owner
@@ -3114,6 +3115,138 @@ func TestSqlStore_SaveUsers(t *testing.T) {
require.Equal(t, users[1].AutoGroups, user.AutoGroups)
}
func TestSqlStore_SaveUserWithEncryption(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
// Enable encryption
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
store.SetFieldEncrypt(fieldEncrypt)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
// rawUser is used to read raw (potentially encrypted) data from the database
// without any gorm hooks or automatic decryption
type rawUser struct {
Id string
Email string
Name string
}
t.Run("save user with empty email and name", func(t *testing.T) {
user := &types.User{
Id: "user-empty-fields",
AccountID: accountID,
Role: types.UserRoleUser,
Email: "",
Name: "",
AutoGroups: []string{"groupA"},
}
err = store.SaveUser(context.Background(), user)
require.NoError(t, err)
// Verify using direct database query that empty strings remain empty (not encrypted)
var raw rawUser
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", user.Id).First(&raw).Error
require.NoError(t, err)
require.Equal(t, "", raw.Email, "empty email should remain empty in database")
require.Equal(t, "", raw.Name, "empty name should remain empty in database")
// Verify manual decryption returns empty strings
decryptedEmail, err := fieldEncrypt.Decrypt(raw.Email)
require.NoError(t, err)
require.Equal(t, "", decryptedEmail)
decryptedName, err := fieldEncrypt.Decrypt(raw.Name)
require.NoError(t, err)
require.Equal(t, "", decryptedName)
})
t.Run("save user with email and name", func(t *testing.T) {
user := &types.User{
Id: "user-with-fields",
AccountID: accountID,
Role: types.UserRoleAdmin,
Email: "test@example.com",
Name: "Test User",
AutoGroups: []string{"groupB"},
}
err = store.SaveUser(context.Background(), user)
require.NoError(t, err)
// Verify using direct database query that the data is encrypted (not plaintext)
var raw rawUser
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", user.Id).First(&raw).Error
require.NoError(t, err)
require.NotEqual(t, "test@example.com", raw.Email, "email should be encrypted in database")
require.NotEqual(t, "Test User", raw.Name, "name should be encrypted in database")
// Verify manual decryption returns correct values
decryptedEmail, err := fieldEncrypt.Decrypt(raw.Email)
require.NoError(t, err)
require.Equal(t, "test@example.com", decryptedEmail)
decryptedName, err := fieldEncrypt.Decrypt(raw.Name)
require.NoError(t, err)
require.Equal(t, "Test User", decryptedName)
})
t.Run("save multiple users with mixed fields", func(t *testing.T) {
users := []*types.User{
{
Id: "batch-user-1",
AccountID: accountID,
Email: "",
Name: "",
},
{
Id: "batch-user-2",
AccountID: accountID,
Email: "batch@example.com",
Name: "Batch User",
},
}
err = store.SaveUsers(context.Background(), users)
require.NoError(t, err)
// Verify first user (empty fields) using direct database query
var raw1 rawUser
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", "batch-user-1").First(&raw1).Error
require.NoError(t, err)
require.Equal(t, "", raw1.Email, "empty email should remain empty in database")
require.Equal(t, "", raw1.Name, "empty name should remain empty in database")
// Verify second user (with fields) using direct database query
var raw2 rawUser
err = store.(*SqlStore).db.Table("users").Select("id, email, name").Where("id = ?", "batch-user-2").First(&raw2).Error
require.NoError(t, err)
require.NotEqual(t, "batch@example.com", raw2.Email, "email should be encrypted in database")
require.NotEqual(t, "Batch User", raw2.Name, "name should be encrypted in database")
// Verify manual decryption returns empty strings for first user
decryptedEmail1, err := fieldEncrypt.Decrypt(raw1.Email)
require.NoError(t, err)
require.Equal(t, "", decryptedEmail1)
decryptedName1, err := fieldEncrypt.Decrypt(raw1.Name)
require.NoError(t, err)
require.Equal(t, "", decryptedName1)
// Verify manual decryption returns correct values for second user
decryptedEmail2, err := fieldEncrypt.Decrypt(raw2.Email)
require.NoError(t, err)
require.Equal(t, "batch@example.com", decryptedEmail2)
decryptedName2, err := fieldEncrypt.Decrypt(raw2.Name)
require.NoError(t, err)
require.Equal(t, "Batch User", decryptedName2)
})
}
func TestSqlStore_DeleteUser(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/management/server/migration"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -204,6 +205,9 @@ type Store interface {
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error)
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
SetFieldEncrypt(enc *crypt.FieldEncrypt)
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
}
@@ -340,6 +344,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.DropIndex[routerTypes.NetworkRouter](ctx, db, "idx_network_routers_id")
},
func(db *gorm.DB) error {
return migration.MigrateNewField[types.User](ctx, db, "name", "")
},
func(db *gorm.DB) error {
return migration.MigrateNewField[types.User](ctx, db, "email", "")
},
}
} // migratePostAuto migrates the SQLite database to the latest schema
func migratePostAuto(ctx context.Context, db *gorm.DB) error {

View File

@@ -0,0 +1,122 @@
package types
import (
"errors"
"net/url"
)
// Identity provider validation errors
var (
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
)
// IdentityProviderType is the type of identity provider
type IdentityProviderType string
const (
// IdentityProviderTypeOIDC is a generic OIDC identity provider
IdentityProviderTypeOIDC IdentityProviderType = "oidc"
// IdentityProviderTypeZitadel is the Zitadel identity provider
IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
// IdentityProviderTypeEntra is the Microsoft Entra (Azure AD) identity provider
IdentityProviderTypeEntra IdentityProviderType = "entra"
// IdentityProviderTypeGoogle is the Google identity provider
IdentityProviderTypeGoogle IdentityProviderType = "google"
// IdentityProviderTypeOkta is the Okta identity provider
IdentityProviderTypeOkta IdentityProviderType = "okta"
// IdentityProviderTypePocketID is the PocketID identity provider
IdentityProviderTypePocketID IdentityProviderType = "pocketid"
// IdentityProviderTypeMicrosoft is the Microsoft identity provider
IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft"
// IdentityProviderTypeAuthentik is the Authentik identity provider
IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
// IdentityProviderTypeKeycloak is the Keycloak identity provider
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
)
// IdentityProvider represents an identity provider configuration
type IdentityProvider struct {
// ID is the unique identifier of the identity provider
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Type is the type of identity provider
Type IdentityProviderType
// Name is a human-readable name for the identity provider
Name string
// Issuer is the OIDC issuer URL
Issuer string
// ClientID is the OAuth2 client ID
ClientID string
// ClientSecret is the OAuth2 client secret
ClientSecret string
}
// Copy returns a copy of the IdentityProvider
func (idp *IdentityProvider) Copy() *IdentityProvider {
return &IdentityProvider{
ID: idp.ID,
AccountID: idp.AccountID,
Type: idp.Type,
Name: idp.Name,
Issuer: idp.Issuer,
ClientID: idp.ClientID,
ClientSecret: idp.ClientSecret,
}
}
// EventMeta returns a map of metadata for activity events
func (idp *IdentityProvider) EventMeta() map[string]any {
return map[string]any{
"name": idp.Name,
"type": string(idp.Type),
"issuer": idp.Issuer,
}
}
// Validate validates the identity provider configuration
func (idp *IdentityProvider) Validate() error {
if idp.Name == "" {
return ErrIdentityProviderNameRequired
}
if idp.Type == "" {
return ErrIdentityProviderTypeRequired
}
if !idp.Type.IsValid() {
return ErrIdentityProviderTypeUnsupported
}
if !idp.Type.HasBuiltInIssuer() && idp.Issuer == "" {
return ErrIdentityProviderIssuerRequired
}
if idp.Issuer != "" {
parsedURL, err := url.Parse(idp.Issuer)
if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
return ErrIdentityProviderIssuerInvalid
}
}
if idp.ClientID == "" {
return ErrIdentityProviderClientIDRequired
}
return nil
}
// IsValid checks if the given type is a supported identity provider type
func (t IdentityProviderType) IsValid() bool {
switch t {
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak:
return true
}
return false
}
// HasBuiltInIssuer returns true for types that don't require an issuer URL
func (t IdentityProviderType) HasBuiltInIssuer() bool {
return t == IdentityProviderTypeGoogle || t == IdentityProviderTypeMicrosoft
}

View File

@@ -0,0 +1,137 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIdentityProvider_Validate(t *testing.T) {
tests := []struct {
name string
idp *IdentityProvider
expectedErr error
}{
{
name: "valid OIDC provider",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
Issuer: "https://example.com",
ClientID: "client-id",
},
expectedErr: nil,
},
{
name: "valid OIDC provider with path",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
Issuer: "https://example.com/oauth2/issuer",
ClientID: "client-id",
},
expectedErr: nil,
},
{
name: "missing name",
idp: &IdentityProvider{
Type: IdentityProviderTypeOIDC,
Issuer: "https://example.com",
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderNameRequired,
},
{
name: "missing type",
idp: &IdentityProvider{
Name: "Test Provider",
Issuer: "https://example.com",
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderTypeRequired,
},
{
name: "invalid type",
idp: &IdentityProvider{
Name: "Test Provider",
Type: "invalid",
Issuer: "https://example.com",
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderTypeUnsupported,
},
{
name: "missing issuer for OIDC",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderIssuerRequired,
},
{
name: "invalid issuer URL - no scheme",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
Issuer: "example.com",
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderIssuerInvalid,
},
{
name: "invalid issuer URL - no host",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
Issuer: "https://",
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderIssuerInvalid,
},
{
name: "invalid issuer URL - just path",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
Issuer: "/oauth2/issuer",
ClientID: "client-id",
},
expectedErr: ErrIdentityProviderIssuerInvalid,
},
{
name: "missing client ID",
idp: &IdentityProvider{
Name: "Test Provider",
Type: IdentityProviderTypeOIDC,
Issuer: "https://example.com",
},
expectedErr: ErrIdentityProviderClientIDRequired,
},
{
name: "Google provider without issuer is valid",
idp: &IdentityProvider{
Name: "Google SSO",
Type: IdentityProviderTypeGoogle,
ClientID: "client-id",
},
expectedErr: nil,
},
{
name: "Microsoft provider without issuer is valid",
idp: &IdentityProvider{
Name: "Microsoft SSO",
Type: IdentityProviderTypeMicrosoft,
ClientID: "client-id",
},
expectedErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.idp.Validate()
assert.Equal(t, tt.expectedErr, err)
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/util/crypt"
)
const (
@@ -65,7 +66,11 @@ type UserInfo struct {
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
PendingApproval bool `json:"pending_approval"`
Password string `json:"password"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
// IdPID is the identity provider ID (connector ID) extracted from the Dex-encoded user ID.
// This field is only populated when the user ID can be decoded from Dex's format.
IdPID string `json:"idp_id,omitempty"`
}
// User represents a user of the system
@@ -96,6 +101,9 @@ type User struct {
Issued string `gorm:"default:api"`
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
Name string `gorm:"default:''"`
Email string `gorm:"default:''"`
}
// IsBlocked returns true if the user is blocked, false otherwise
@@ -143,10 +151,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
}
if userData == nil {
name := u.Name
if u.IsServiceUser {
name = u.ServiceUserName
}
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Email: u.Email,
Name: name,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
@@ -178,6 +192,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
PendingApproval: u.PendingApproval,
Password: userData.Password,
}, nil
}
@@ -204,11 +219,13 @@ func (u *User) Copy() *User {
CreatedAt: u.CreatedAt,
Issued: u.Issued,
IntegrationReference: u.IntegrationReference,
Email: u.Email,
Name: u.Name,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, email string, name string) *User {
return &User{
Id: id,
Role: role,
@@ -218,20 +235,70 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
AutoGroups: autoGroups,
Issued: issued,
CreatedAt: time.Now().UTC(),
Name: name,
Email: email,
}
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
func NewRegularUser(id, email, name string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, email, name)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, "", "")
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
func NewOwnerUser(id string, email string, name string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, email, name)
}
// EncryptSensitiveData encrypts the user's sensitive fields (Email and Name) in place.
func (u *User) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if u.Email != "" {
u.Email, err = enc.Encrypt(u.Email)
if err != nil {
return fmt.Errorf("encrypt email: %w", err)
}
}
if u.Name != "" {
u.Name, err = enc.Encrypt(u.Name)
if err != nil {
return fmt.Errorf("encrypt name: %w", err)
}
}
return nil
}
// DecryptSensitiveData decrypts the user's sensitive fields (Email and Name) in place.
func (u *User) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
var err error
if u.Email != "" {
u.Email, err = enc.Decrypt(u.Email)
if err != nil {
return fmt.Errorf("decrypt email: %w", err)
}
}
if u.Name != "" {
u.Name, err = enc.Decrypt(u.Name)
if err != nil {
return fmt.Errorf("decrypt name: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,298 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUser_EncryptSensitiveData(t *testing.T) {
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
t.Run("encrypt email and name", func(t *testing.T) {
user := &User{
Id: "user-1",
Email: "test@example.com",
Name: "Test User",
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
})
t.Run("encrypt empty email and name", func(t *testing.T) {
user := &User{
Id: "user-2",
Email: "",
Name: "",
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", user.Email, "empty email should remain empty")
assert.Equal(t, "", user.Name, "empty name should remain empty")
})
t.Run("encrypt only email", func(t *testing.T) {
user := &User{
Id: "user-3",
Email: "test@example.com",
Name: "",
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
assert.Equal(t, "", user.Name, "empty name should remain empty")
})
t.Run("encrypt only name", func(t *testing.T) {
user := &User{
Id: "user-4",
Email: "",
Name: "Test User",
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", user.Email, "empty email should remain empty")
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
})
t.Run("nil encryptor returns no error", func(t *testing.T) {
user := &User{
Id: "user-5",
Email: "test@example.com",
Name: "Test User",
}
err := user.EncryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
})
}
func TestUser_DecryptSensitiveData(t *testing.T) {
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
t.Run("decrypt email and name", func(t *testing.T) {
originalEmail := "test@example.com"
originalName := "Test User"
user := &User{
Id: "user-1",
Email: originalEmail,
Name: originalName,
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
err = user.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
})
t.Run("decrypt empty email and name", func(t *testing.T) {
user := &User{
Id: "user-2",
Email: "",
Name: "",
}
err := user.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", user.Email, "empty email should remain empty")
assert.Equal(t, "", user.Name, "empty name should remain empty")
})
t.Run("decrypt only email", func(t *testing.T) {
originalEmail := "test@example.com"
user := &User{
Id: "user-3",
Email: originalEmail,
Name: "",
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
err = user.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
assert.Equal(t, "", user.Name, "empty name should remain empty")
})
t.Run("decrypt only name", func(t *testing.T) {
originalName := "Test User"
user := &User{
Id: "user-4",
Email: "",
Name: originalName,
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
err = user.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, "", user.Email, "empty email should remain empty")
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
})
t.Run("nil encryptor returns no error", func(t *testing.T) {
user := &User{
Id: "user-5",
Email: "test@example.com",
Name: "Test User",
}
err := user.DecryptSensitiveData(nil)
require.NoError(t, err)
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
})
t.Run("decrypt with invalid ciphertext returns error", func(t *testing.T) {
user := &User{
Id: "user-6",
Email: "not-valid-base64-ciphertext!!!",
Name: "Test User",
}
err := user.DecryptSensitiveData(fieldEncrypt)
require.Error(t, err)
assert.Contains(t, err.Error(), "decrypt email")
})
t.Run("decrypt with wrong key returns error", func(t *testing.T) {
originalEmail := "test@example.com"
originalName := "Test User"
user := &User{
Id: "user-7",
Email: originalEmail,
Name: originalName,
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
differentKey, err := crypt.GenerateKey()
require.NoError(t, err)
differentEncrypt, err := crypt.NewFieldEncrypt(differentKey)
require.NoError(t, err)
err = user.DecryptSensitiveData(differentEncrypt)
require.Error(t, err)
assert.Contains(t, err.Error(), "decrypt email")
})
}
func TestUser_EncryptDecryptRoundTrip(t *testing.T) {
key, err := crypt.GenerateKey()
require.NoError(t, err)
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
require.NoError(t, err)
testCases := []struct {
name string
email string
uname string
}{
{
name: "standard email and name",
email: "user@example.com",
uname: "John Doe",
},
{
name: "email with special characters",
email: "user+tag@sub.example.com",
uname: "O'Brien, Mary-Jane",
},
{
name: "unicode characters",
email: "user@example.com",
uname: "Jean-Pierre Müller 日本語",
},
{
name: "long values",
email: "very.long.email.address.that.is.quite.extended@subdomain.example.organization.com",
uname: "A Very Long Name That Contains Many Words And Is Quite Extended For Testing Purposes",
},
{
name: "empty email only",
email: "",
uname: "Name Only",
},
{
name: "empty name only",
email: "email@only.com",
uname: "",
},
{
name: "both empty",
email: "",
uname: "",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
user := &User{
Id: "test-user",
Email: tc.email,
Name: tc.uname,
}
err := user.EncryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
if tc.email != "" {
assert.NotEqual(t, tc.email, user.Email, "email should be encrypted")
}
if tc.uname != "" {
assert.NotEqual(t, tc.uname, user.Name, "name should be encrypted")
}
err = user.DecryptSensitiveData(fieldEncrypt)
require.NoError(t, err)
assert.Equal(t, tc.email, user.Email, "decrypted email should match original")
assert.Equal(t, tc.uname, user.Name, "decrypted name should match original")
})
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -40,7 +41,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
}
newUserID := uuid.New().String()
newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI)
newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI, "", "")
newUser.AccountID = accountID
log.WithContext(ctx).Debugf("New User: %v", newUser)
@@ -104,7 +105,12 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
inviterID = createdBy
}
idpUser, err := am.createNewIdpUser(ctx, accountID, inviterID, invite)
var idpUser *idp.UserData
if IsEmbeddedIdp(am.idpManager) {
idpUser, err = am.createEmbeddedIdpUser(ctx, accountID, inviterID, invite)
} else {
idpUser, err = am.createNewIdpUser(ctx, accountID, inviterID, invite)
}
if err != nil {
return nil, err
}
@@ -117,18 +123,26 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference,
CreatedAt: time.Now().UTC(),
Email: invite.Email,
Name: invite.Name,
}
if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err
}
_, err = am.refreshCache(ctx, accountID)
if err != nil {
return nil, err
if !IsEmbeddedIdp(am.idpManager) {
_, err = am.refreshCache(ctx, accountID)
if err != nil {
return nil, err
}
}
am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil)
eventType := activity.UserInvited
if IsEmbeddedIdp(am.idpManager) {
eventType = activity.UserCreated
}
am.StoreEvent(ctx, userID, newUser.Id, accountID, eventType, nil)
return newUser.ToUserInfo(idpUser)
}
@@ -172,6 +186,34 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviterUser.Email)
}
// createEmbeddedIdpUser validates the invite and creates a new user in the embedded IdP.
// Unlike createNewIdpUser, this method fetches user data directly from the database
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
if err != nil {
return nil, fmt.Errorf("failed to get inviter user: %w", err)
}
if inviter == nil {
return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist", inviterID)
}
// check if the user is already registered with this email => reject
existingUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
for _, user := range existingUsers {
if strings.EqualFold(user.Email, invite.Email) {
return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account")
}
}
return am.idpManager.CreateUser(ctx, invite.Email, invite.Name, accountID, inviter.Email)
}
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
}
@@ -757,7 +799,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func (am *DefaultAccountManager) getUserInfo(ctx context.Context, user *types.User, accountID string) (*types.UserInfo, error) {
if !isNil(am.idpManager) && !user.IsServiceUser {
if !isNil(am.idpManager) && !user.IsServiceUser && !IsEmbeddedIdp(am.idpManager) {
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
if err != nil {
return nil, err
@@ -808,7 +850,10 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
}
// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) {
func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userAuth auth.UserAuth) (*types.Account, error) {
userID := userAuth.UserId
domain := userAuth.Domain
start := time.Now()
unlock := am.Store.AcquireGlobalLock(ctx)
defer unlock()
@@ -819,7 +864,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
account, err := am.Store.GetAccountByUser(ctx, userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err = am.newAccount(ctx, userID, lowerDomain)
account, err = am.newAccount(ctx, userID, lowerDomain, userAuth.Email, userAuth.Name)
if err != nil {
return nil, err
}
@@ -884,7 +929,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
var queriedUsers []*idp.UserData
var err error
if !isNil(am.idpManager) {
// embedded IdP ensures that we have user data (email and name) stored in the database.
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range accountUsers {
@@ -921,6 +967,10 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if err != nil {
return nil, err
}
// Try to decode Dex user ID to extract the IdP ID (connector ID)
if _, connectorID, decodeErr := dex.DecodeDexUserID(accountUser.Id); decodeErr == nil && connectorID != "" {
info.IdPID = connectorID
}
userInfosMap[accountUser.Id] = info
}
@@ -942,7 +992,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
info = &types.UserInfo{
ID: localUser.Id,
Email: "",
Email: localUser.Email,
Name: name,
Role: string(localUser.Role),
AutoGroups: localUser.AutoGroups,
@@ -951,6 +1001,10 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
NonDeletable: localUser.NonDeletable,
}
}
// Try to decode Dex user ID to extract the IdP ID (connector ID)
if _, connectorID, decodeErr := dex.DecodeDexUserID(localUser.Id); decodeErr == nil && connectorID != "" {
info.IdPID = connectorID
}
userInfosMap[info.ID] = info
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"os"
"reflect"
"testing"
"time"
@@ -29,6 +30,7 @@ import (
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
@@ -58,7 +60,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = s.SaveAccount(context.Background(), account)
if err != nil {
@@ -105,7 +107,7 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockTargetUserId] = &types.User{
Id: mockTargetUserId,
IsServiceUser: false,
@@ -133,7 +135,7 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockTargetUserId] = &types.User{
Id: mockTargetUserId,
IsServiceUser: true,
@@ -165,7 +167,7 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -190,7 +192,7 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -215,7 +217,7 @@ func TestUser_DeletePAT(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockUserID] = &types.User{
Id: mockUserID,
PATs: map[string]*types.PersonalAccessToken{
@@ -258,7 +260,7 @@ func TestUser_GetPAT(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockUserID] = &types.User{
Id: mockUserID,
AccountID: mockAccountID,
@@ -298,7 +300,7 @@ func TestUser_GetAllPATs(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockUserID] = &types.User{
Id: mockUserID,
AccountID: mockAccountID,
@@ -362,6 +364,8 @@ func TestUser_Copy(t *testing.T) {
ID: 0,
IntegrationType: "test",
},
Email: "whatever@gmail.com",
Name: "John Doe",
}
err := validateStruct(user)
@@ -408,7 +412,7 @@ func TestUser_CreateServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -455,7 +459,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -503,7 +507,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -534,7 +538,7 @@ func TestUser_InviteNewUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -641,7 +645,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockServiceUserID] = tt.serviceUser
err = store.SaveAccount(context.Background(), account)
@@ -680,7 +684,7 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -707,7 +711,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
targetId := "user2"
account.Users[targetId] = &types.User{
@@ -801,7 +805,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
targetId := "user2"
account.Users[targetId] = &types.User{
@@ -969,7 +973,7 @@ func TestDefaultAccountManager_GetUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -1005,9 +1009,9 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account.Users["normal_user1"] = types.NewRegularUser("normal_user1")
account.Users["normal_user2"] = types.NewRegularUser("normal_user2")
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users["normal_user1"] = types.NewRegularUser("normal_user1", "", "")
account.Users["normal_user2"] = types.NewRegularUser("normal_user2", "", "")
err = store.SaveAccount(context.Background(), account)
if err != nil {
@@ -1047,7 +1051,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
externalUser := &types.User{
Id: "externalUser",
Role: types.UserRoleUser,
@@ -1104,7 +1108,7 @@ func TestUser_IsAdmin(t *testing.T) {
user := types.NewAdminUser(mockUserID)
assert.True(t, user.HasAdminPower())
user = types.NewRegularUser(mockUserID)
user = types.NewRegularUser(mockUserID, "", "")
assert.False(t, user.HasAdminPower())
}
@@ -1115,7 +1119,7 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockServiceUserID] = &types.User{
Id: mockServiceUserID,
Role: "user",
@@ -1149,7 +1153,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) {
}
t.Cleanup(cleanup)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", false)
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
account.Users[mockServiceUserID] = &types.User{
Id: mockServiceUserID,
Role: "user",
@@ -1320,13 +1324,13 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// create an account and an admin user
account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io")
account, err := manager.GetOrCreateAccountByUser(context.Background(), auth.UserAuth{UserId: ownerUserID, Domain: "netbird.io"})
if err != nil {
t.Fatal(err)
}
// create other users
account.Users[regularUserID] = types.NewRegularUser(regularUserID)
account.Users[regularUserID] = types.NewRegularUser(regularUserID, "", "")
account.Users[adminUserID] = types.NewAdminUser(adminUserID)
account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"}
err = manager.Store.SaveAccount(context.Background(), account)
@@ -1516,7 +1520,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
}
t.Cleanup(cleanup)
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", false)
account1 := newAccountWithId(context.Background(), "account1", "ownerAccount1", "", "", "", false)
targetId := "user2"
account1.Users[targetId] = &types.User{
Id: targetId,
@@ -1525,7 +1529,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
}
require.NoError(t, s.SaveAccount(context.Background(), account1))
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", false)
account2 := newAccountWithId(context.Background(), "account2", "ownerAccount2", "", "", "", false)
require.NoError(t, s.SaveAccount(context.Background(), account2))
permissionsManager := permissions.NewManager(s)
@@ -1552,7 +1556,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
}
t.Cleanup(cleanup)
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", false)
account1 := newAccountWithId(context.Background(), "account1", "account1Owner", "", "", "", false)
account1.Settings.RegularUsersViewBlocked = false
account1.Users["blocked-user"] = &types.User{
Id: "blocked-user",
@@ -1574,7 +1578,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
}
require.NoError(t, store.SaveAccount(context.Background(), account1))
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", false)
account2 := newAccountWithId(context.Background(), "account2", "account2Owner", "", "", "", false)
account2.Users["settings-blocked-user"] = &types.User{
Id: "settings-blocked-user",
Role: types.UserRoleUser,
@@ -1771,7 +1775,7 @@ func TestApproveUser(t *testing.T) {
}
// Create account with admin and pending approval user
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", "", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
@@ -1782,7 +1786,7 @@ func TestApproveUser(t *testing.T) {
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser := types.NewRegularUser("pending-user", "", "")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
@@ -1807,12 +1811,12 @@ func TestApproveUser(t *testing.T) {
assert.Contains(t, err.Error(), "not pending approval")
// Test approval by non-admin should fail
regularUser := types.NewRegularUser("regular-user")
regularUser := types.NewRegularUser("regular-user", "", "")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
pendingUser2 := types.NewRegularUser("pending-user-2")
pendingUser2 := types.NewRegularUser("pending-user-2", "", "")
pendingUser2.AccountID = account.Id
pendingUser2.Blocked = true
pendingUser2.PendingApproval = true
@@ -1830,7 +1834,7 @@ func TestRejectUser(t *testing.T) {
}
// Create account with admin and pending approval user
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", false)
account := newAccountWithId(context.Background(), "account-1", "admin-user", "example.com", "", "", false)
err = manager.Store.SaveAccount(context.Background(), account)
require.NoError(t, err)
@@ -1841,7 +1845,7 @@ func TestRejectUser(t *testing.T) {
require.NoError(t, err)
// Create user pending approval
pendingUser := types.NewRegularUser("pending-user")
pendingUser := types.NewRegularUser("pending-user", "", "")
pendingUser.AccountID = account.Id
pendingUser.Blocked = true
pendingUser.PendingApproval = true
@@ -1857,7 +1861,7 @@ func TestRejectUser(t *testing.T) {
require.Error(t, err)
// Test rejection of non-pending user should fail
regularUser := types.NewRegularUser("regular-user")
regularUser := types.NewRegularUser("regular-user", "", "")
regularUser.AccountID = account.Id
err = manager.Store.SaveUser(context.Background(), regularUser)
require.NoError(t, err)
@@ -1867,7 +1871,7 @@ func TestRejectUser(t *testing.T) {
assert.Contains(t, err.Error(), "not pending approval")
// Test rejection by non-admin should fail
pendingUser2 := types.NewRegularUser("pending-user-2")
pendingUser2 := types.NewRegularUser("pending-user-2", "", "")
pendingUser2.AccountID = account.Id
pendingUser2.Blocked = true
pendingUser2.PendingApproval = true
@@ -1877,3 +1881,149 @@ func TestRejectUser(t *testing.T) {
err = manager.RejectUser(context.Background(), account.Id, regularUser.Id, pendingUser2.Id)
require.Error(t, err)
}
func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
ctx := context.Background()
// Create temporary directory for Dex
tmpDir := t.TempDir()
dexDataDir := tmpDir + "/dex"
require.NoError(t, os.MkdirAll(dexDataDir, 0700))
// Create embedded IDP config
embeddedIdPConfig := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: idp.EmbeddedStorageConfig{
Type: "sqlite3",
Config: idp.EmbeddedStorageTypeConfig{
File: dexDataDir + "/dex.db",
},
},
}
// Create embedded IDP manager
embeddedIdp, err := idp.NewEmbeddedIdPManager(ctx, embeddedIdPConfig, nil)
require.NoError(t, err)
defer func() { _ = embeddedIdp.Stop(ctx) }()
// Create test store
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", tmpDir)
require.NoError(t, err)
defer cleanup()
// Create account with owner user
account := newAccountWithId(ctx, mockAccountID, mockUserID, "", "owner@test.com", "Owner User", false)
require.NoError(t, testStore.SaveAccount(ctx, account))
// Create mock network map controller
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
// Create account manager with embedded IDP
permissionsManager := permissions.NewManager(testStore)
am := DefaultAccountManager{
Store: testStore,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: permissionsManager,
idpManager: embeddedIdp,
cacheLoading: map[string]chan struct{}{},
networkMapController: networkMapControllerMock,
}
// Initialize cache manager
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
require.NoError(t, err)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
t.Run("create regular user returns password", func(t *testing.T) {
userInfo, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
Email: "newuser@test.com",
Name: "New User",
Role: "user",
AutoGroups: []string{},
IsServiceUser: false,
})
require.NoError(t, err)
require.NotNil(t, userInfo)
// Verify user data
assert.Equal(t, "newuser@test.com", userInfo.Email)
assert.Equal(t, "New User", userInfo.Name)
assert.Equal(t, "user", userInfo.Role)
assert.NotEmpty(t, userInfo.ID)
// IMPORTANT: Password should be returned for embedded IDP
assert.NotEmpty(t, userInfo.Password, "Password should be returned for embedded IDP user")
t.Logf("Created user: ID=%s, Email=%s, Password=%s", userInfo.ID, userInfo.Email, userInfo.Password)
// Verify user ID is in Dex encoded format
rawUserID, connectorID, err := dex.DecodeDexUserID(userInfo.ID)
require.NoError(t, err)
assert.NotEmpty(t, rawUserID)
assert.Equal(t, "local", connectorID)
t.Logf("Decoded user ID: rawUserID=%s, connectorID=%s", rawUserID, connectorID)
// Verify user exists in database with correct data
dbUser, err := testStore.GetUserByUserID(ctx, store.LockingStrengthNone, userInfo.ID)
require.NoError(t, err)
assert.Equal(t, "newuser@test.com", dbUser.Email)
assert.Equal(t, "New User", dbUser.Name)
// Store user ID for delete test
createdUserID := userInfo.ID
t.Run("delete user works", func(t *testing.T) {
err := am.DeleteUser(ctx, mockAccountID, mockUserID, createdUserID)
require.NoError(t, err)
// Verify user is deleted from database
_, err = testStore.GetUserByUserID(ctx, store.LockingStrengthNone, createdUserID)
assert.Error(t, err, "User should be deleted from database")
})
})
t.Run("create service user does not return password", func(t *testing.T) {
userInfo, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
Name: "Service User",
Role: "user",
AutoGroups: []string{},
IsServiceUser: true,
})
require.NoError(t, err)
require.NotNil(t, userInfo)
assert.True(t, userInfo.IsServiceUser)
assert.Equal(t, "Service User", userInfo.Name)
// Service users don't have passwords
assert.Empty(t, userInfo.Password, "Service users should not have passwords")
})
t.Run("duplicate email fails", func(t *testing.T) {
// Create first user
_, err := am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
Email: "duplicate@test.com",
Name: "First User",
Role: "user",
AutoGroups: []string{},
IsServiceUser: false,
})
require.NoError(t, err)
// Try to create second user with same email
_, err = am.CreateUser(ctx, mockAccountID, mockUserID, &types.UserInfo{
Email: "duplicate@test.com",
Name: "Second User",
Role: "user",
AutoGroups: []string{},
IsServiceUser: false,
})
assert.Error(t, err, "Creating user with duplicate email should fail")
t.Logf("Duplicate email error: %v", err)
})
}