Merge branch 'main' into prototype/reverse-proxy

This commit is contained in:
Viktor Liu
2026-02-08 17:44:49 +08:00
41 changed files with 1638 additions and 268 deletions

View File

@@ -795,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool {
return ok
}
// IsLocalAuthDisabled checks if local (email/password) authentication is disabled.
// Returns true only when using embedded IDP with local auth disabled in config.
func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool {
if isNil(i) {
return false
}
embeddedIdp, ok := i.(*idp.EmbeddedIdPManager)
if !ok {
return false
}
return embeddedIdp.IsLocalAuthDisabled()
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
@@ -1657,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1671,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}

View File

@@ -58,7 +58,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
@@ -114,8 +114,8 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)

View File

@@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1961,6 +1961,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := time.Now().UTC()
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
})
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
})
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -1983,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3176,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
assert.NoError(b, err)
}

View File

@@ -140,15 +140,15 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
// Check if embedded IdP is enabled
// Check if embedded IdP is enabled for instance manager
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router)
peers.AddEndpoints(accountManager, router, networkMapController)
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)

View File

@@ -36,24 +36,22 @@ const (
// handler is a handler that handles the server.Account HTTP endpoints
type handler struct {
accountManager account.Manager
settingsManager settings.Manager
embeddedIdpEnabled bool
accountManager account.Manager
settingsManager settings.Manager
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled)
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
accountsHandler := newHandler(accountManager, settingsManager)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler {
func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler {
return &handler{
accountManager: accountManager,
settingsManager: settingsManager,
embeddedIdpEnabled: embeddedIdpEnabled,
accountManager: accountManager,
settingsManager: settingsManager,
}
}
@@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled)
resp := toAccountResponse(accountID, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled)
resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding)
util.WriteJSONObject(r.Context(), w, &resp)
}
@@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account {
func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account {
jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
@@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
EmbeddedIdpEnabled: &embeddedIdpEnabled,
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
LocalAuthDisabled: &settings.LocalAuthDisabled,
}
if settings.NetworkRange.IsValid() {

View File

@@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
AnyTimes()
return &handler{
embeddedIdpEnabled: false,
accountManager: &mock_server.MockAccountManager{
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil
@@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: true,
expectedID: accountID,
@@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr("latest"),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,
@@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
EmbeddedIdpEnabled: br(false),
LocalAuthDisabled: br(false),
},
expectedArray: false,
expectedID: accountID,

View File

@@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w)
return
}
log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired)
util.WriteJSONObject(r.Context(), w, api.InstanceStatus{
SetupRequired: setupRequired,
})

View File

@@ -17,6 +17,7 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -26,11 +27,12 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
accountManager account.Manager
permissionsManager permissions.Manager
networkMapController network_map.Controller
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
peersHandler := NewHandler(accountManager, networkMapController)
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
return &Handler{
accountManager: accountManager,
networkMapController: networkMapController,
permissionsManager: permissionsManager,
}
}
@@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
user, err := h.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userID)
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -13,13 +13,15 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"go.uber.org/mock/gomock"
ugomock "go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
}
ctrl := gomock.NewController(t)
ctrl := ugomock.NewController(t)
networkMapController := network_map.NewMockController(ctrl)
networkMapController.EXPECT().
@@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
Return("domain").
AnyTimes()
ctrl2 := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl2)
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
},
networkMapController: networkMapController,
permissionsManager: permissionsManager,
}
}

View File

@@ -205,6 +205,14 @@ func TestCreateInvite(t *testing.T) {
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "local auth disabled",
requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) {
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
},
},
{
name: "invalid JSON",
requestBody: `{invalid json}`,
@@ -376,6 +384,15 @@ func TestAcceptInvite(t *testing.T) {
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
},
},
{
name: "local auth disabled",
token: testInviteToken,
requestBody: `{"password":"SecurePass123!"}`,
expectedStatus: http.StatusPreconditionFailed,
mockFunc: func(ctx context.Context, token, password string) error {
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
},
},
{
name: "missing token",
token: "",

View File

@@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
proxyController := integrations.NewController(store)
userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{})
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)

View File

@@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct {
Owner *OwnerConfig
// SignKeyRefreshEnabled enables automatic key rotation for signing keys
SignKeyRefreshEnabled bool
// LocalAuthDisabled disables the local (email/password) authentication connector.
// When true, users cannot authenticate via email/password, only via external identity providers.
// Existing local users are preserved and will be able to login again if re-enabled.
// Cannot be enabled if no external identity provider connectors are configured.
LocalAuthDisabled bool
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
@@ -110,6 +115,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
Issuer: "NetBird",
Theme: "light",
},
// Always enable password DB initially - we disable the local connector after startup if needed.
// This ensures Dex has at least one connector during initialization.
EnablePasswordDB: true,
StaticClients: []storage.Client{
{
@@ -197,11 +204,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe
return nil, err
}
log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config)
provider, err := dex.NewProviderFromYAML(ctx, yamlConfig)
if err != nil {
return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err)
}
// If local auth is disabled, validate that other connectors exist
if config.LocalAuthDisabled {
hasOthers, err := provider.HasNonLocalConnectors(ctx)
if err != nil {
_ = provider.Stop(ctx)
return nil, fmt.Errorf("failed to check connectors: %w", err)
}
if !hasOthers {
_ = provider.Stop(ctx)
return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured")
}
// Ensure local connector is removed (it might exist from a previous run)
if err := provider.DisableLocalAuth(ctx); err != nil {
_ = provider.Stop(ctx)
return nil, fmt.Errorf("failed to disable local auth: %w", err)
}
log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used")
}
log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer)
return &EmbeddedIdPManager{
@@ -286,6 +314,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
return nil, fmt.Errorf("failed to list users: %w", err)
}
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users))
indexedUsers := make(map[string][]*UserData)
for _, user := range users {
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{
@@ -295,11 +325,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]*
})
}
log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID]))
return indexedUsers, nil
}
// CreateUser creates a new user in the embedded IdP.
func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
if m.config.LocalAuthDisabled {
return nil, fmt.Errorf("local user creation is disabled")
}
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountCreateUser()
}
@@ -369,6 +405,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) (
// Unlike CreateUser which auto-generates a password, this method uses the provided password.
// This is useful for instance setup where the user provides their own password.
func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) {
if m.config.LocalAuthDisabled {
return nil, fmt.Errorf("local user creation is disabled")
}
if m.appMetrics != nil {
m.appMetrics.IDPMetrics().CountCreateUser()
}
@@ -558,3 +598,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string {
func (m *EmbeddedIdPManager) GetUserIDClaim() string {
return defaultUserIDClaim
}
// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration.
func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
return m.config.LocalAuthDisabled
}
// HasNonLocalConnectors checks if there are any identity provider connectors other than local.
func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) {
return m.provider.HasNonLocalConnectors(ctx)
}

View File

@@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) {
})
}
}
func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) {
ctx := context.Background()
t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
_, err = NewEmbeddedIdPManager(ctx, config, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "no other identity providers configured")
})
t.Run("local auth enabled by default", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(tmpDir, "dex.db"),
},
},
}
manager, err := NewEmbeddedIdPManager(ctx, config, nil)
require.NoError(t, err)
defer func() { _ = manager.Stop(ctx) }()
// Verify local auth is enabled by default
assert.False(t, manager.IsLocalAuthDisabled())
})
t.Run("start with local auth disabled when connector exists", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First, create a manager with local auth enabled and add a connector
config1 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
require.NoError(t, err)
// Create a user
userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com")
require.NoError(t, err)
userID := userData.ID
// Add an external connector (Google doesn't require OIDC discovery)
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
ID: "google-test",
Name: "Google Test",
Type: "google",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
})
require.NoError(t, err)
// Stop the first manager
err = manager1.Stop(ctx)
require.NoError(t, err)
// Now create a new manager with local auth disabled
config2 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
require.NoError(t, err)
defer func() { _ = manager2.Stop(ctx) }()
// Verify local auth is disabled via config
assert.True(t, manager2.IsLocalAuthDisabled())
// Verify the user still exists in storage (just can't login via local)
lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{})
require.NoError(t, err)
assert.Equal(t, "preserved@example.com", lookedUp.Email)
})
t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First, create a manager and add an external connector
config1 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
require.NoError(t, err)
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
ID: "google-test",
Name: "Google Test",
Type: "google",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
})
require.NoError(t, err)
err = manager1.Stop(ctx)
require.NoError(t, err)
// Create manager with local auth disabled
config2 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
require.NoError(t, err)
defer func() { _ = manager2.Stop(ctx) }()
// Try to create a user - should fail
_, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com")
require.Error(t, err)
assert.Contains(t, err.Error(), "local user creation is disabled")
})
t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
dbFile := filepath.Join(tmpDir, "dex.db")
// First, create a manager and add an external connector
config1 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager1, err := NewEmbeddedIdPManager(ctx, config1, nil)
require.NoError(t, err)
_, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{
ID: "google-test",
Name: "Google Test",
Type: "google",
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
})
require.NoError(t, err)
err = manager1.Stop(ctx)
require.NoError(t, err)
// Create manager with local auth disabled
config2 := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "http://localhost:5556/dex",
LocalAuthDisabled: true,
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: dbFile,
},
},
}
manager2, err := NewEmbeddedIdPManager(ctx, config2, nil)
require.NoError(t, err)
defer func() { _ = manager2.Stop(ctx) }()
// Try to create a user with password - should fail
_, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User")
require.Error(t, err)
assert.Contains(t, err.Error(), "local user creation is disabled")
})
}

View File

@@ -104,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager)
}
func (m *DefaultManager) loadSetupRequired(ctx context.Context) error {
// Check if there are any accounts in the NetBird store
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return err
}
hasAccounts := numAccounts > 0
// Check if there are any users in the embedded IdP (Dex)
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
if err != nil {
return err
}
hasLocalUsers := len(users) > 0
m.setupMu.Lock()
m.setupRequired = len(users) == 0
m.setupRequired = !(hasAccounts || hasLocalUsers)
m.setupMu.Unlock()
return nil

View File

@@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) {
initialPeers := 10
additionalPeers := 10
expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself
var peers []wgtypes.Key
for i := 0; i < initialPeers; i++ {
@@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) {
peers = append(peers, key)
}
// Track the maximum peer count each peer has seen
type peerState struct {
mu sync.Mutex
maxPeerCount int
done bool
}
peerStates := make(map[string]*peerState)
for _, pk := range peers {
peerStates[pk.PublicKey().String()] = &peerState{}
}
var wg sync.WaitGroup
wg.Add(initialPeers + initialPeers*additionalPeers)
wg.Add(initialPeers) // One completion per initial peer
var syncClients []mgmtProto.ManagementService_SyncClient
for _, pk := range peers {
@@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) {
syncClients = append(syncClients, s)
go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) {
pubKey := pk.PublicKey().String()
state := peerStates[pubKey]
for {
encMsg := &mgmtProto.EncryptedMessage{}
err := syncStream.RecvMsg(encMsg)
@@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) {
}
decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk)
if decErr != nil {
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr)
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr)
return
}
resp := &mgmtProto.SyncResponse{}
umErr := pb.Unmarshal(decryptedBytes, resp)
if umErr != nil {
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr)
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr)
return
}
// We only count if there's a new peer update
if len(resp.GetRemotePeers()) > 0 {
// Track the maximum peer count seen (due to debouncing, updates are coalesced)
peerCount := len(resp.GetRemotePeers())
state.mu.Lock()
if peerCount > state.maxPeerCount {
state.maxPeerCount = peerCount
}
// Signal completion when this peer has seen all expected peers
if !state.done && state.maxPeerCount >= expectedPeerCount {
state.done = true
wg.Done()
}
state.mu.Unlock()
}
}(pk, s)
}
@@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) {
time.Sleep(time.Duration(n) * time.Millisecond)
}
wg.Wait()
// Wait for debouncer to flush final updates (debounce interval is 1000ms)
time.Sleep(1500 * time.Millisecond)
// Wait with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success - all peers received expected peer count
case <-time.After(5 * time.Second):
// Timeout - report which peers didn't receive all updates
t.Error("Timeout waiting for all peers to receive updates")
for pubKey, state := range peerStates {
state.mu.Lock()
if state.maxPeerCount < expectedPeerCount {
t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount)
}
state.mu.Unlock()
}
}
for _, sc := range syncClients {
err := sc.CloseSend()

View File

@@ -37,8 +37,8 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -214,16 +214,15 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime)
}
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
return nil
}
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
@@ -323,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}

View File

@@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
// syncTime is used as the LastSeen timestamp and for stale request detection
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
var err error
var skipped bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
@@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return err
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
if connected && !syncTime.After(peer.Status.LastSeen) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
skipped = true
return nil
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
return err
})
if skipped {
return nil
}
if err != nil {
return err
}
@@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
newStatus.LastSeen = syncTime
newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {

View File

@@ -24,19 +24,28 @@ type Manager interface {
UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error)
}
// IdpConfig holds IdP-related configuration that is set at runtime
// and not stored in the database.
type IdpConfig struct {
EmbeddedIdpEnabled bool
LocalAuthDisabled bool
}
type managerImpl struct {
store store.Store
extraSettingsManager extra_settings.Manager
userManager users.Manager
permissionsManager permissions.Manager
idpConfig IdpConfig
}
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager {
func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager {
return &managerImpl{
store: store,
extraSettingsManager: extraSettingsManager,
userManager: userManager,
permissionsManager: permissionsManager,
idpConfig: idpConfig,
}
}
@@ -74,6 +83,10 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
}
// Fill in IdP-related runtime settings
settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled
settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled
return settings, nil
}

View File

@@ -55,6 +55,14 @@ type Settings struct {
// AutoUpdateVersion client auto-update version
AutoUpdateVersion string `gorm:"default:'disabled'"`
// EmbeddedIdpEnabled indicates if the embedded identity provider is enabled.
// This is a runtime-only field, not stored in the database.
EmbeddedIdpEnabled bool `gorm:"-"`
// LocalAuthDisabled indicates if local (email/password) authentication is disabled.
// This is a runtime-only field, not stored in the database.
LocalAuthDisabled bool `gorm:"-"`
}
// Copy copies the Settings struct
@@ -76,6 +84,8 @@ func (s *Settings) Copy() *Settings {
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
AutoUpdateVersion: s.AutoUpdateVersion,
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
LocalAuthDisabled: s.LocalAuthDisabled,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -191,6 +191,10 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
// Unlike createNewIdpUser, this method fetches user data directly from the database
// since the embedded IdP usage ensures the username and email are stored locally in the User table.
func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) {
if IsLocalAuthDisabled(ctx, am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
}
inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID)
if err != nil {
return nil, fmt.Errorf("failed to get inviter user: %w", err)
@@ -1462,6 +1466,10 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID
return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if IsLocalAuthDisabled(ctx, am.idpManager) {
return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
}
if err := validateUserInvite(invite); err != nil {
return nil, err
}
@@ -1621,6 +1629,10 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa
return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider")
}
if IsLocalAuthDisabled(ctx, am.idpManager) {
return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider")
}
if password == "" {
return status.Errorf(status.InvalidArgument, "password is required")
}