[management] Legacy to embedded IdP migration tool (#5586)

This commit is contained in:
shuuri-labs
2026-04-01 12:53:19 +01:00
committed by GitHub
parent 4d3e2f8ad3
commit 940f530ac2
24 changed files with 4023 additions and 31 deletions

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/telemetry"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -48,6 +49,8 @@ type EmbeddedIdPConfig struct {
// Existing local users are preserved and will be able to login again if re-enabled.
// Cannot be enabled if no external identity provider connectors are configured.
LocalAuthDisabled bool
// StaticConnectors are additional connectors to seed during initialization
StaticConnectors []dex.Connector
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
@@ -157,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
RedirectURIs: cliRedirectURIs,
},
},
StaticConnectors: c.StaticConnectors,
}
// Add owner user if provided
@@ -193,6 +197,9 @@ type OAuthConfigProvider interface {
// Management server has embedded Dex and can validate tokens via localhost,
// avoiding external network calls and DNS resolution issues during startup.
GetLocalKeysLocation() string
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
// or nil if direct key fetching is not supported (falls back to HTTP).
GetKeyFetcher() nbjwt.KeyFetcher
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
@@ -593,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
return m.config.CLIRedirectURIs
}
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
return m.provider.GetJWKS
}
// GetKeysLocation returns the JWKS endpoint URL for token validation.
func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()

View File

@@ -0,0 +1,235 @@
// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0
// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later.
// It includes functions to seed connectors and migrate existing users to use these connectors.
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// Server is the dependency interface that migration functions use to access
// the main data store and the activity event store.
type Server interface {
Store() Store
EventStore() EventStore // may return nil
}
const idpSeedInfoKey = "IDP_SEED_INFO"
const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN"
func isDryRun() bool {
return os.Getenv(dryRunEnvKey) == "true"
}
var ErrNoSeedInfo = errors.New("no seed info found in environment")
// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it,
// and JSON-unmarshals it into a dex.Connector. Returns nil if not set.
func SeedConnectorFromEnv() (*dex.Connector, error) {
val, ok := os.LookupEnv(idpSeedInfoKey)
if !ok || val == "" {
return nil, ErrNoSeedInfo
}
decoded, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
var conn dex.Connector
if err := json.Unmarshal(decoded, &conn); err != nil {
return nil, fmt.Errorf("json unmarshal: %w", err)
}
return &conn, nil
}
// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and
// the activity store, if present) so that it encodes the given connector ID,
// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true
// to log what would happen without writing any changes.
func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error {
ctx := context.Background()
if isDryRun() {
log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written")
}
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Reconciliation pass: fix activity store for users already migrated in main DB
// but whose activity references may still use old IDs (from a previous partial failure).
if s.EventStore() != nil && !isDryRun() {
if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil {
return err
}
}
var migratedCount, skippedCount int
for _, user := range users {
_, _, decErr := dex.DecodeDexUserID(user.Id)
if decErr == nil {
skippedCount++
continue
}
newUserID := dex.EncodeDexUserID(user.Id, conn.ID)
if isDryRun() {
log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID)
migratedCount++
continue
}
if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil {
return err
}
migratedCount++
}
if isDryRun() {
log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount)
} else {
log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount)
}
return nil
}
// reconcileActivityStore updates activity store references for users already migrated
// in the main DB whose activity entries may still use old IDs from a previous partial failure.
func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error {
for _, user := range users {
originalID, _, err := dex.DecodeDexUserID(user.Id)
if err != nil {
// skip users that aren't migrated, they will be handled in the main migration loop
continue
}
if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil {
return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err)
}
}
return nil
}
// migrateUser updates a single user's ID in both the main store and the activity store.
func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error {
if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil {
return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err)
}
if s.EventStore() == nil {
return nil
}
if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil {
return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err)
}
return nil
}
// PopulateUserInfo fetches user email and name from the external IDP and updates
// the store for users that are missing this information.
func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error {
ctx := context.Background()
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Build a map of IDP user ID -> UserData from the external IDP
allAccounts, err := idpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to fetch accounts from IDP: %w", err)
}
idpUsers := make(map[string]*idp.UserData)
for _, accountUsers := range allAccounts {
for _, userData := range accountUsers {
idpUsers[userData.ID] = userData
}
}
log.Infof("fetched %d users from IDP", len(idpUsers))
var updatedCount, skippedCount, notFoundCount int
for _, user := range users {
if user.IsServiceUser {
skippedCount++
continue
}
if user.Email != "" && user.Name != "" {
skippedCount++
continue
}
// The user ID in the store may be the original IDP ID or a Dex-encoded ID.
// Try to decode the Dex format first to get the original IDP ID.
lookupID := user.Id
if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil {
lookupID = originalID
}
idpUser, found := idpUsers[lookupID]
if !found {
notFoundCount++
log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID)
continue
}
email := user.Email
name := user.Name
if email == "" && idpUser.Email != "" {
email = idpUser.Email
}
if name == "" && idpUser.Name != "" {
name = idpUser.Name
}
if email == user.Email && name == user.Name {
skippedCount++
continue
}
if dryRun {
log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
continue
}
if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil {
return fmt.Errorf("failed to update user info for %s: %w", user.Id, err)
}
log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
}
if dryRun {
log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
} else {
log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
}
return nil
}

View File

@@ -0,0 +1,828 @@
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// testStore is a hand-written mock for MigrationStore.
type testStore struct {
listUsersFunc func(ctx context.Context) ([]*types.User, error)
updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error
updateUserInfoFunc func(ctx context.Context, userID, email, name string) error
checkSchemaFunc func(checks []SchemaCheck) []SchemaError
updateCalls []updateUserIDCall
updateInfoCalls []updateUserInfoCall
}
type updateUserIDCall struct {
AccountID string
OldUserID string
NewUserID string
}
type updateUserInfoCall struct {
UserID string
Email string
Name string
}
func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) {
return s.listUsersFunc(ctx)
}
func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID})
return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID)
}
func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name})
if s.updateUserInfoFunc != nil {
return s.updateUserInfoFunc(ctx, userID, email, name)
}
return nil
}
func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError {
if s.checkSchemaFunc != nil {
return s.checkSchemaFunc(checks)
}
return nil
}
type testServer struct {
store Store
eventStore EventStore
}
func (s *testServer) Store() Store { return s.store }
func (s *testServer) EventStore() EventStore { return s.eventStore }
func TestSeedConnectorFromEnv(t *testing.T) {
t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) {
os.Unsetenv(idpSeedInfoKey)
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "")
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns error on invalid base64", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "not-valid-base64!!!")
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "base64 decode")
})
t.Run("returns error on invalid JSON", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("not json"))
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "json unmarshal")
})
t.Run("successfully decodes valid connector", func(t *testing.T) {
expected := dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-provider",
Config: map[string]any{
"issuer": "https://example.com",
"clientID": "my-client-id",
"clientSecret": "my-secret",
},
}
data, err := json.Marshal(expected)
require.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.Equal(t, expected.Type, conn.Type)
assert.Equal(t, expected.Name, conn.Name)
assert.Equal(t, expected.ID, conn.ID)
assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"])
})
}
func TestMigrateUsersToStaticConnectors(t *testing.T) {
connector := &dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-connector",
}
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("migrates single user with correct encoded ID", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
expectedNewID := dex.EncodeDexUserID("user-1", "test-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
require.Len(t, ms.updateCalls, 1)
assert.Equal(t, "account-1", ms.updateCalls[0].AccountID)
assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID)
assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID)
})
t.Run("migrates multiple users", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 3)
})
t.Run("returns error when UpdateUserID fails", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
callCount++
if callCount == 2 {
return fmt.Errorf("update failed")
}
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user ID for user user-2")
})
t.Run("stops on first UpdateUserID error", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return fmt.Errorf("update failed")
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Len(t, ms.updateCalls, 1) // stopped after first error
})
t.Run("skips already migrated users", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
// Only user-2 and user-3 should be migrated
assert.Len(t, ms.updateCalls, 2)
assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID)
assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID)
})
t.Run("dry run does not call UpdateUserID", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("dry run skips already migrated users", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("dry run disabled by default", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run
})
}
func TestPopulateUserInfo(t *testing.T) {
noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("returns error when GetAllAccounts fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return nil, fmt.Errorf("idp error")
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch accounts from IDP")
})
t.Run("updates user with missing email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "user1@example.com", Name: "User One"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User One", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing email when name exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing name when email exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name)
})
t.Run("skips users that already have both email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips service users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips users not found in IDP", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) {
dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User", ms.updateInfoCalls[0].Name)
})
t.Run("handles multiple users across multiple accounts", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"},
{Id: "user-3", AccountID: "acc-2", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "User 1"},
{ID: "user-2", Email: "u2@example.com", Name: "User 2"},
},
"acc-2": {
{ID: "user-3", Email: "u3@example.com", Name: "User 3"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 2)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID)
assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email)
})
t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user info for user-1")
})
t.Run("stops on first UpdateUserInfo error", func(t *testing.T) {
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
callCount++
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Equal(t, 1, callCount)
})
t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
t.Fatal("UpdateUserInfo should not be called in dry-run mode")
return nil
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, true)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips user when IDP has empty email and name too", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "", Name: ""}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
}
func TestSchemaError_String(t *testing.T) {
t.Run("missing table", func(t *testing.T) {
e := SchemaError{Table: "jobs"}
assert.Equal(t, `table "jobs" is missing`, e.String())
})
t.Run("missing column", func(t *testing.T) {
e := SchemaError{Table: "users", Column: "email"}
assert.Equal(t, `column "email" on table "users" is missing`, e.String())
})
}
func TestRequiredSchema(t *testing.T) {
// Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo.
expectedTables := []string{
"users",
"personal_access_tokens",
"peers",
"accounts",
"user_invites",
"proxy_access_tokens",
"jobs",
}
schemaTableNames := make([]string, len(RequiredSchema))
for i, s := range RequiredSchema {
schemaTableNames[i] = s.Table
}
for _, expected := range expectedTables {
assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected)
}
}
func TestCheckSchema_MockStore(t *testing.T) {
t.Run("returns nil when all schema exists", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return nil
},
}
errs := ms.CheckSchema(RequiredSchema)
assert.Empty(t, errs)
})
t.Run("returns errors for missing tables", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "jobs"},
{Table: "proxy_access_tokens"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "jobs", errs[0].Table)
assert.Equal(t, "", errs[0].Column)
assert.Equal(t, "proxy_access_tokens", errs[1].Table)
})
t.Run("returns errors for missing columns", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "users", Column: "email"},
{Table: "users", Column: "name"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "users", errs[0].Table)
assert.Equal(t, "email", errs[0].Column)
})
}

View File

@@ -0,0 +1,82 @@
package migration
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/types"
)
// SchemaCheck represents a table and the columns required on it.
type SchemaCheck struct {
Table string
Columns []string
}
// RequiredSchema lists all tables and columns that the migration tool needs.
// If any are missing, the user must upgrade their management server first so
// that the automatic GORM migrations create them.
var RequiredSchema = []SchemaCheck{
{Table: "users", Columns: []string{"id", "email", "name", "account_id"}},
{Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}},
{Table: "peers", Columns: []string{"user_id"}},
{Table: "accounts", Columns: []string{"created_by"}},
{Table: "user_invites", Columns: []string{"created_by"}},
{Table: "proxy_access_tokens", Columns: []string{"created_by"}},
{Table: "jobs", Columns: []string{"triggered_by"}},
}
// SchemaError describes a single missing table or column.
type SchemaError struct {
Table string
Column string // empty when the whole table is missing
}
func (e SchemaError) String() string {
if e.Column == "" {
return fmt.Sprintf("table %q is missing", e.Table)
}
return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table)
}
// Store defines the data store operations required for IdP user migration.
// This interface is separate from the main store.Store interface because these methods
// are only used during one-time migration and should be removed once migration tooling
// is no longer needed.
//
// The SQL store implementations (SqlStore) already have these methods on their concrete
// types, so they satisfy this interface via Go's structural typing with zero code changes.
type Store interface {
// ListUsers returns all users across all accounts.
ListUsers(ctx context.Context) ([]*types.User, error)
// UpdateUserID atomically updates a user's ID and all foreign key references
// across the database (peers, groups, policies, PATs, etc.).
UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error
// UpdateUserInfo updates a user's email and name in the store.
UpdateUserInfo(ctx context.Context, userID, email, name string) error
// CheckSchema verifies that all tables and columns required by the migration
// exist in the database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
}
// RequiredEventSchema lists all tables and columns that the migration tool needs
// in the activity/event store.
var RequiredEventSchema = []SchemaCheck{
{Table: "events", Columns: []string{"initiator_id", "target_id"}},
{Table: "deleted_users", Columns: []string{"id"}},
}
// EventStore defines the activity event store operations required for migration.
// Like Store, this is a temporary interface for migration tooling only.
type EventStore interface {
// CheckSchema verifies that all tables and columns required by the migration
// exist in the event database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
// UpdateUserID updates all event references (initiator_id, target_id) and
// deleted_users records to use the new user ID format.
UpdateUserID(ctx context.Context, oldUserID, newUserID string) error
}