mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] Legacy to embedded IdP migration tool (#5586)
This commit is contained in:
61
management/server/activity/store/sql_store_idp_migration.go
Normal file
61
management/server/activity/store/sql_store_idp_migration.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package store
|
||||
|
||||
// This file contains migration-only methods on Store.
|
||||
// They satisfy the migration.MigrationEventStore interface via duck typing.
|
||||
// Delete this file when migration tooling is no longer needed.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp/migration"
|
||||
)
|
||||
|
||||
// CheckSchema verifies that all tables and columns required by the migration exist in the event database.
|
||||
func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
|
||||
migrator := store.db.Migrator()
|
||||
var errs []migration.SchemaError
|
||||
|
||||
for _, check := range checks {
|
||||
if !migrator.HasTable(check.Table) {
|
||||
errs = append(errs, migration.SchemaError{Table: check.Table})
|
||||
continue
|
||||
}
|
||||
for _, col := range check.Columns {
|
||||
if !migrator.HasColumn(check.Table, col) {
|
||||
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// UpdateUserID updates all references to oldUserID in events and deleted_users tables.
|
||||
func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error {
|
||||
return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&activity.Event{}).
|
||||
Where("initiator_id = ?", oldUserID).
|
||||
Update("initiator_id", newUserID).Error; err != nil {
|
||||
return fmt.Errorf("update events.initiator_id: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Model(&activity.Event{}).
|
||||
Where("target_id = ?", oldUserID).
|
||||
Update("target_id", newUserID).Error; err != nil {
|
||||
return fmt.Errorf("update events.target_id: %w", err)
|
||||
}
|
||||
|
||||
// Raw exec: GORM can't update a PK via Model().Update()
|
||||
if err := tx.Exec(
|
||||
"UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID,
|
||||
).Error; err != nil {
|
||||
return fmt.Errorf("update deleted_users.id: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
161
management/server/activity/store/sql_store_idp_migration_test.go
Normal file
161
management/server/activity/store/sql_store_idp_migration_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUpdateUserID(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
newStore := func(t *testing.T) *Store {
|
||||
t.Helper()
|
||||
key, _ := crypt.GenerateKey()
|
||||
s, err := NewSqlStore(ctx, t.TempDir(), key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { s.Close(ctx) }) //nolint
|
||||
return s
|
||||
}
|
||||
|
||||
t.Run("updates initiator_id in events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "old-user",
|
||||
TargetID: "some-peer",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "old-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, "new-user", result[0].InitiatorID)
|
||||
})
|
||||
|
||||
t.Run("updates target_id in events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "some-admin",
|
||||
TargetID: "old-user",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "old-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, "new-user", result[0].TargetID)
|
||||
})
|
||||
|
||||
t.Run("updates deleted_users id", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
// Save an event with email/name meta to create a deleted_users row for "old-user"
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "admin",
|
||||
TargetID: "old-user",
|
||||
AccountID: accountID,
|
||||
Meta: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "old-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Save another event referencing new-user with email/name meta.
|
||||
// This should upsert (not conflict) because the PK was already migrated.
|
||||
_, err = store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "admin",
|
||||
TargetID: "new-user",
|
||||
AccountID: accountID,
|
||||
Meta: map[string]any{
|
||||
"email": "user@example.com",
|
||||
"name": "Test User",
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// The deleted user info should be retrievable via Get (joined on target_id)
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
for _, ev := range result {
|
||||
assert.Equal(t, "new-user", ev.TargetID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no-op when old user ID does not exist", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
|
||||
err := store.UpdateUserID(ctx, "nonexistent-user", "new-user")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("only updates matching user leaves others unchanged", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
accountID := "account_1"
|
||||
|
||||
_, err := store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user-a",
|
||||
TargetID: "peer-1",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = store.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user-b",
|
||||
TargetID: "peer-2",
|
||||
AccountID: accountID,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = store.UpdateUserID(ctx, "user-a", "user-a-new")
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := store.Get(ctx, accountID, 0, 10, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, 2)
|
||||
|
||||
for _, ev := range result {
|
||||
if ev.TargetID == "peer-1" {
|
||||
assert.Equal(t, "user-a-new", ev.InitiatorID)
|
||||
} else {
|
||||
assert.Equal(t, "user-b", ev.InitiatorID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -33,15 +33,20 @@ type manager struct {
|
||||
extractor *nbjwt.ClaimsExtractor
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
|
||||
// @note if invalid/missing parameters are sent the validator will instantiate
|
||||
// but it will fail when validating and parsing the token
|
||||
jwtValidator := nbjwt.NewValidator(
|
||||
issuer,
|
||||
allAudiences,
|
||||
keysLocation,
|
||||
idpRefreshKeys,
|
||||
)
|
||||
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
|
||||
var jwtValidator *nbjwt.Validator
|
||||
if keyFetcher != nil {
|
||||
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
|
||||
} else {
|
||||
// @note if invalid/missing parameters are sent the validator will instantiate
|
||||
// but it will fail when validating and parsing the token
|
||||
jwtValidator = nbjwt.NewValidator(
|
||||
issuer,
|
||||
allAudiences,
|
||||
keysLocation,
|
||||
idpRefreshKeys,
|
||||
)
|
||||
}
|
||||
|
||||
claimsExtractor := nbjwt.NewClaimsExtractor(
|
||||
nbjwt.WithAudience(audience),
|
||||
|
||||
@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
|
||||
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
|
||||
if err != nil {
|
||||
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
|
||||
err = manager.MarkPATUsed(context.Background(), "tokenId")
|
||||
if err != nil {
|
||||
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
||||
// these tests only assert groups are parsed from token as per account settings
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
|
||||
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
keyId := "test-key"
|
||||
|
||||
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
|
||||
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
|
||||
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
|
||||
|
||||
customClaim := func(name string) string {
|
||||
return fmt.Sprintf("%s/%s", audience, name)
|
||||
|
||||
@@ -119,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
am.SetServiceManager(serviceManager)
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
authManagerMock := &serverauth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
@@ -248,7 +248,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
am.SetServiceManager(serviceManager)
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
|
||||
authManagerMock := &serverauth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,8 @@ type EmbeddedIdPConfig struct {
|
||||
// Existing local users are preserved and will be able to login again if re-enabled.
|
||||
// Cannot be enabled if no external identity provider connectors are configured.
|
||||
LocalAuthDisabled bool
|
||||
// StaticConnectors are additional connectors to seed during initialization
|
||||
StaticConnectors []dex.Connector
|
||||
}
|
||||
|
||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||
@@ -157,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
RedirectURIs: cliRedirectURIs,
|
||||
},
|
||||
},
|
||||
StaticConnectors: c.StaticConnectors,
|
||||
}
|
||||
|
||||
// Add owner user if provided
|
||||
@@ -193,6 +197,9 @@ type OAuthConfigProvider interface {
|
||||
// Management server has embedded Dex and can validate tokens via localhost,
|
||||
// avoiding external network calls and DNS resolution issues during startup.
|
||||
GetLocalKeysLocation() string
|
||||
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
|
||||
// or nil if direct key fetching is not supported (falls back to HTTP).
|
||||
GetKeyFetcher() nbjwt.KeyFetcher
|
||||
GetClientIDs() []string
|
||||
GetUserIDClaim() string
|
||||
GetTokenEndpoint() string
|
||||
@@ -593,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
|
||||
return m.config.CLIRedirectURIs
|
||||
}
|
||||
|
||||
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
|
||||
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
|
||||
return m.provider.GetJWKS
|
||||
}
|
||||
|
||||
// GetKeysLocation returns the JWKS endpoint URL for token validation.
|
||||
func (m *EmbeddedIdPManager) GetKeysLocation() string {
|
||||
return m.provider.GetKeysLocation()
|
||||
|
||||
235
management/server/idp/migration/migration.go
Normal file
235
management/server/idp/migration/migration.go
Normal file
@@ -0,0 +1,235 @@
|
||||
// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0
|
||||
// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later.
|
||||
// It includes functions to seed connectors and migrate existing users to use these connectors.
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// Server is the dependency interface that migration functions use to access
|
||||
// the main data store and the activity event store.
|
||||
type Server interface {
|
||||
Store() Store
|
||||
EventStore() EventStore // may return nil
|
||||
}
|
||||
|
||||
const idpSeedInfoKey = "IDP_SEED_INFO"
|
||||
const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN"
|
||||
|
||||
func isDryRun() bool {
|
||||
return os.Getenv(dryRunEnvKey) == "true"
|
||||
}
|
||||
|
||||
var ErrNoSeedInfo = errors.New("no seed info found in environment")
|
||||
|
||||
// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it,
|
||||
// and JSON-unmarshals it into a dex.Connector. Returns nil if not set.
|
||||
func SeedConnectorFromEnv() (*dex.Connector, error) {
|
||||
val, ok := os.LookupEnv(idpSeedInfoKey)
|
||||
if !ok || val == "" {
|
||||
return nil, ErrNoSeedInfo
|
||||
}
|
||||
|
||||
decoded, err := base64.StdEncoding.DecodeString(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
|
||||
var conn dex.Connector
|
||||
if err := json.Unmarshal(decoded, &conn); err != nil {
|
||||
return nil, fmt.Errorf("json unmarshal: %w", err)
|
||||
}
|
||||
|
||||
return &conn, nil
|
||||
}
|
||||
|
||||
// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and
|
||||
// the activity store, if present) so that it encodes the given connector ID,
|
||||
// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true
|
||||
// to log what would happen without writing any changes.
|
||||
func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error {
|
||||
ctx := context.Background()
|
||||
|
||||
if isDryRun() {
|
||||
log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written")
|
||||
}
|
||||
|
||||
users, err := s.Store().ListUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
// Reconciliation pass: fix activity store for users already migrated in main DB
|
||||
// but whose activity references may still use old IDs (from a previous partial failure).
|
||||
if s.EventStore() != nil && !isDryRun() {
|
||||
if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var migratedCount, skippedCount int
|
||||
|
||||
for _, user := range users {
|
||||
_, _, decErr := dex.DecodeDexUserID(user.Id)
|
||||
if decErr == nil {
|
||||
skippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
newUserID := dex.EncodeDexUserID(user.Id, conn.ID)
|
||||
|
||||
if isDryRun() {
|
||||
log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID)
|
||||
migratedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migratedCount++
|
||||
}
|
||||
|
||||
if isDryRun() {
|
||||
log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount)
|
||||
} else {
|
||||
log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcileActivityStore updates activity store references for users already migrated
|
||||
// in the main DB whose activity entries may still use old IDs from a previous partial failure.
|
||||
func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error {
|
||||
for _, user := range users {
|
||||
originalID, _, err := dex.DecodeDexUserID(user.Id)
|
||||
if err != nil {
|
||||
// skip users that aren't migrated, they will be handled in the main migration loop
|
||||
continue
|
||||
}
|
||||
if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil {
|
||||
return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateUser updates a single user's ID in both the main store and the activity store.
|
||||
func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error {
|
||||
if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil {
|
||||
return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err)
|
||||
}
|
||||
|
||||
if s.EventStore() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil {
|
||||
return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopulateUserInfo fetches user email and name from the external IDP and updates
|
||||
// the store for users that are missing this information.
|
||||
func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error {
|
||||
ctx := context.Background()
|
||||
|
||||
users, err := s.Store().ListUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list users: %w", err)
|
||||
}
|
||||
|
||||
// Build a map of IDP user ID -> UserData from the external IDP
|
||||
allAccounts, err := idpManager.GetAllAccounts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch accounts from IDP: %w", err)
|
||||
}
|
||||
|
||||
idpUsers := make(map[string]*idp.UserData)
|
||||
for _, accountUsers := range allAccounts {
|
||||
for _, userData := range accountUsers {
|
||||
idpUsers[userData.ID] = userData
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("fetched %d users from IDP", len(idpUsers))
|
||||
|
||||
var updatedCount, skippedCount, notFoundCount int
|
||||
|
||||
for _, user := range users {
|
||||
if user.IsServiceUser {
|
||||
skippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if user.Email != "" && user.Name != "" {
|
||||
skippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
// The user ID in the store may be the original IDP ID or a Dex-encoded ID.
|
||||
// Try to decode the Dex format first to get the original IDP ID.
|
||||
lookupID := user.Id
|
||||
if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil {
|
||||
lookupID = originalID
|
||||
}
|
||||
|
||||
idpUser, found := idpUsers[lookupID]
|
||||
if !found {
|
||||
notFoundCount++
|
||||
log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID)
|
||||
continue
|
||||
}
|
||||
|
||||
email := user.Email
|
||||
name := user.Name
|
||||
if email == "" && idpUser.Email != "" {
|
||||
email = idpUser.Email
|
||||
}
|
||||
if name == "" && idpUser.Name != "" {
|
||||
name = idpUser.Name
|
||||
}
|
||||
|
||||
if email == user.Email && name == user.Name {
|
||||
skippedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name)
|
||||
updatedCount++
|
||||
continue
|
||||
}
|
||||
|
||||
if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil {
|
||||
return fmt.Errorf("failed to update user info for %s: %w", user.Id, err)
|
||||
}
|
||||
|
||||
log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name)
|
||||
updatedCount++
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
|
||||
} else {
|
||||
log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
828
management/server/idp/migration/migration_test.go
Normal file
828
management/server/idp/migration/migration_test.go
Normal file
@@ -0,0 +1,828 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// testStore is a hand-written mock for MigrationStore.
|
||||
type testStore struct {
|
||||
listUsersFunc func(ctx context.Context) ([]*types.User, error)
|
||||
updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error
|
||||
updateUserInfoFunc func(ctx context.Context, userID, email, name string) error
|
||||
checkSchemaFunc func(checks []SchemaCheck) []SchemaError
|
||||
updateCalls []updateUserIDCall
|
||||
updateInfoCalls []updateUserInfoCall
|
||||
}
|
||||
|
||||
type updateUserIDCall struct {
|
||||
AccountID string
|
||||
OldUserID string
|
||||
NewUserID string
|
||||
}
|
||||
|
||||
type updateUserInfoCall struct {
|
||||
UserID string
|
||||
Email string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) {
|
||||
return s.listUsersFunc(ctx)
|
||||
}
|
||||
|
||||
func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID})
|
||||
return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID)
|
||||
}
|
||||
|
||||
func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
|
||||
s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name})
|
||||
if s.updateUserInfoFunc != nil {
|
||||
return s.updateUserInfoFunc(ctx, userID, email, name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError {
|
||||
if s.checkSchemaFunc != nil {
|
||||
return s.checkSchemaFunc(checks)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type testServer struct {
|
||||
store Store
|
||||
eventStore EventStore
|
||||
}
|
||||
|
||||
func (s *testServer) Store() Store { return s.store }
|
||||
func (s *testServer) EventStore() EventStore { return s.eventStore }
|
||||
|
||||
func TestSeedConnectorFromEnv(t *testing.T) {
|
||||
t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) {
|
||||
os.Unsetenv(idpSeedInfoKey)
|
||||
|
||||
conn, err := SeedConnectorFromEnv()
|
||||
assert.ErrorIs(t, err, ErrNoSeedInfo)
|
||||
assert.Nil(t, conn)
|
||||
})
|
||||
|
||||
t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) {
|
||||
t.Setenv(idpSeedInfoKey, "")
|
||||
|
||||
conn, err := SeedConnectorFromEnv()
|
||||
assert.ErrorIs(t, err, ErrNoSeedInfo)
|
||||
assert.Nil(t, conn)
|
||||
})
|
||||
|
||||
t.Run("returns error on invalid base64", func(t *testing.T) {
|
||||
t.Setenv(idpSeedInfoKey, "not-valid-base64!!!")
|
||||
|
||||
conn, err := SeedConnectorFromEnv()
|
||||
assert.NotErrorIs(t, err, ErrNoSeedInfo)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, conn)
|
||||
assert.Contains(t, err.Error(), "base64 decode")
|
||||
})
|
||||
|
||||
t.Run("returns error on invalid JSON", func(t *testing.T) {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("not json"))
|
||||
t.Setenv(idpSeedInfoKey, encoded)
|
||||
|
||||
conn, err := SeedConnectorFromEnv()
|
||||
assert.NotErrorIs(t, err, ErrNoSeedInfo)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, conn)
|
||||
assert.Contains(t, err.Error(), "json unmarshal")
|
||||
})
|
||||
|
||||
t.Run("successfully decodes valid connector", func(t *testing.T) {
|
||||
expected := dex.Connector{
|
||||
Type: "oidc",
|
||||
Name: "Test Provider",
|
||||
ID: "test-provider",
|
||||
Config: map[string]any{
|
||||
"issuer": "https://example.com",
|
||||
"clientID": "my-client-id",
|
||||
"clientSecret": "my-secret",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(expected)
|
||||
require.NoError(t, err)
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
t.Setenv(idpSeedInfoKey, encoded)
|
||||
|
||||
conn, err := SeedConnectorFromEnv()
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
assert.Equal(t, expected.Type, conn.Type)
|
||||
assert.Equal(t, expected.Name, conn.Name)
|
||||
assert.Equal(t, expected.ID, conn.ID)
|
||||
assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateUsersToStaticConnectors(t *testing.T) {
|
||||
connector := &dex.Connector{
|
||||
Type: "oidc",
|
||||
Name: "Test Provider",
|
||||
ID: "test-connector",
|
||||
}
|
||||
|
||||
t.Run("succeeds with no users", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("returns error when ListUsers fails", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return nil, fmt.Errorf("db error")
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to list users")
|
||||
})
|
||||
|
||||
t.Run("migrates single user with correct encoded ID", func(t *testing.T) {
|
||||
user := &types.User{Id: "user-1", AccountID: "account-1"}
|
||||
expectedNewID := dex.EncodeDexUserID("user-1", "test-connector")
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{user}, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, ms.updateCalls, 1)
|
||||
assert.Equal(t, "account-1", ms.updateCalls[0].AccountID)
|
||||
assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID)
|
||||
assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID)
|
||||
})
|
||||
|
||||
t.Run("migrates multiple users", func(t *testing.T) {
|
||||
users := []*types.User{
|
||||
{Id: "user-1", AccountID: "account-1"},
|
||||
{Id: "user-2", AccountID: "account-1"},
|
||||
{Id: "user-3", AccountID: "account-2"},
|
||||
}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms.updateCalls, 3)
|
||||
})
|
||||
|
||||
t.Run("returns error when UpdateUserID fails", func(t *testing.T) {
|
||||
users := []*types.User{
|
||||
{Id: "user-1", AccountID: "account-1"},
|
||||
{Id: "user-2", AccountID: "account-1"},
|
||||
}
|
||||
|
||||
callCount := 0
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
callCount++
|
||||
if callCount == 2 {
|
||||
return fmt.Errorf("update failed")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to update user ID for user user-2")
|
||||
})
|
||||
|
||||
t.Run("stops on first UpdateUserID error", func(t *testing.T) {
|
||||
users := []*types.User{
|
||||
{Id: "user-1", AccountID: "account-1"},
|
||||
{Id: "user-2", AccountID: "account-1"},
|
||||
}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
return fmt.Errorf("update failed")
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.Error(t, err)
|
||||
assert.Len(t, ms.updateCalls, 1) // stopped after first error
|
||||
})
|
||||
|
||||
t.Run("skips already migrated users", func(t *testing.T) {
|
||||
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
|
||||
users := []*types.User{
|
||||
{Id: alreadyMigratedID, AccountID: "account-1"},
|
||||
}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms.updateCalls, 0)
|
||||
})
|
||||
|
||||
t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) {
|
||||
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
|
||||
users := []*types.User{
|
||||
{Id: alreadyMigratedID, AccountID: "account-1"},
|
||||
{Id: "user-2", AccountID: "account-1"},
|
||||
{Id: "user-3", AccountID: "account-2"},
|
||||
}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
// Only user-2 and user-3 should be migrated
|
||||
assert.Len(t, ms.updateCalls, 2)
|
||||
assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID)
|
||||
assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID)
|
||||
})
|
||||
|
||||
t.Run("dry run does not call UpdateUserID", func(t *testing.T) {
|
||||
t.Setenv(dryRunEnvKey, "true")
|
||||
|
||||
users := []*types.User{
|
||||
{Id: "user-1", AccountID: "account-1"},
|
||||
{Id: "user-2", AccountID: "account-1"},
|
||||
}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
t.Fatal("UpdateUserID should not be called in dry-run mode")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms.updateCalls, 0)
|
||||
})
|
||||
|
||||
t.Run("dry run skips already migrated users", func(t *testing.T) {
|
||||
t.Setenv(dryRunEnvKey, "true")
|
||||
|
||||
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
|
||||
users := []*types.User{
|
||||
{Id: alreadyMigratedID, AccountID: "account-1"},
|
||||
{Id: "user-2", AccountID: "account-1"},
|
||||
}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return users, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
t.Fatal("UpdateUserID should not be called in dry-run mode")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("dry run disabled by default", func(t *testing.T) {
|
||||
user := &types.User{Id: "user-1", AccountID: "account-1"}
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{user}, nil
|
||||
},
|
||||
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := MigrateUsersToStaticConnectors(srv, connector)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run
|
||||
})
|
||||
}
|
||||
|
||||
func TestPopulateUserInfo(t *testing.T) {
|
||||
noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }
|
||||
|
||||
t.Run("succeeds with no users", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, ms.updateInfoCalls)
|
||||
})
|
||||
|
||||
t.Run("returns error when ListUsers fails", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return nil, fmt.Errorf("db error")
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to list users")
|
||||
})
|
||||
|
||||
t.Run("returns error when GetAllAccounts fails", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return nil, fmt.Errorf("idp error")
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to fetch accounts from IDP")
|
||||
})
|
||||
|
||||
t.Run("updates user with missing email and name", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {
|
||||
{ID: "user-1", Email: "user1@example.com", Name: "User One"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, ms.updateInfoCalls, 1)
|
||||
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
|
||||
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
|
||||
assert.Equal(t, "User One", ms.updateInfoCalls[0].Name)
|
||||
})
|
||||
|
||||
t.Run("updates only missing email when name exists", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, ms.updateInfoCalls, 1)
|
||||
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
|
||||
assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name)
|
||||
})
|
||||
|
||||
t.Run("updates only missing name when email exists", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, ms.updateInfoCalls, 1)
|
||||
assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email)
|
||||
assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name)
|
||||
})
|
||||
|
||||
t.Run("skips users that already have both email and name", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, ms.updateInfoCalls)
|
||||
})
|
||||
|
||||
t.Run("skips service users", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, ms.updateInfoCalls)
|
||||
})
|
||||
|
||||
t.Run("skips users not found in IDP", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, ms.updateInfoCalls)
|
||||
})
|
||||
|
||||
t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) {
|
||||
dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector")
|
||||
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, ms.updateInfoCalls, 1)
|
||||
assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID)
|
||||
assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email)
|
||||
assert.Equal(t, "User", ms.updateInfoCalls[0].Name)
|
||||
})
|
||||
|
||||
t.Run("handles multiple users across multiple accounts", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
{Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"},
|
||||
{Id: "user-3", AccountID: "acc-2", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {
|
||||
{ID: "user-1", Email: "u1@example.com", Name: "User 1"},
|
||||
{ID: "user-2", Email: "u2@example.com", Name: "User 2"},
|
||||
},
|
||||
"acc-2": {
|
||||
{ID: "user-3", Email: "u3@example.com", Name: "User 3"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
require.Len(t, ms.updateInfoCalls, 2)
|
||||
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
|
||||
assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email)
|
||||
assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID)
|
||||
assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email)
|
||||
})
|
||||
|
||||
t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
|
||||
return fmt.Errorf("db write error")
|
||||
},
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "failed to update user info for user-1")
|
||||
})
|
||||
|
||||
t.Run("stops on first UpdateUserInfo error", func(t *testing.T) {
|
||||
callCount := 0
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
|
||||
callCount++
|
||||
return fmt.Errorf("db write error")
|
||||
},
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {
|
||||
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
|
||||
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, 1, callCount)
|
||||
})
|
||||
|
||||
t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
|
||||
t.Fatal("UpdateUserInfo should not be called in dry-run mode")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {
|
||||
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
|
||||
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, true)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, ms.updateInfoCalls)
|
||||
})
|
||||
|
||||
t.Run("skips user when IDP has empty email and name too", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
|
||||
return []*types.User{
|
||||
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
|
||||
}, nil
|
||||
},
|
||||
updateUserIDFunc: noopUpdateID,
|
||||
}
|
||||
mockIDP := &idp.MockIDP{
|
||||
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
|
||||
return map[string][]*idp.UserData{
|
||||
"acc-1": {{ID: "user-1", Email: "", Name: ""}},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := &testServer{store: ms}
|
||||
err := PopulateUserInfo(srv, mockIDP, false)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, ms.updateInfoCalls)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSchemaError_String(t *testing.T) {
|
||||
t.Run("missing table", func(t *testing.T) {
|
||||
e := SchemaError{Table: "jobs"}
|
||||
assert.Equal(t, `table "jobs" is missing`, e.String())
|
||||
})
|
||||
|
||||
t.Run("missing column", func(t *testing.T) {
|
||||
e := SchemaError{Table: "users", Column: "email"}
|
||||
assert.Equal(t, `column "email" on table "users" is missing`, e.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequiredSchema(t *testing.T) {
|
||||
// Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo.
|
||||
expectedTables := []string{
|
||||
"users",
|
||||
"personal_access_tokens",
|
||||
"peers",
|
||||
"accounts",
|
||||
"user_invites",
|
||||
"proxy_access_tokens",
|
||||
"jobs",
|
||||
}
|
||||
|
||||
schemaTableNames := make([]string, len(RequiredSchema))
|
||||
for i, s := range RequiredSchema {
|
||||
schemaTableNames[i] = s.Table
|
||||
}
|
||||
|
||||
for _, expected := range expectedTables {
|
||||
assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckSchema_MockStore(t *testing.T) {
|
||||
t.Run("returns nil when all schema exists", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
errs := ms.CheckSchema(RequiredSchema)
|
||||
assert.Empty(t, errs)
|
||||
})
|
||||
|
||||
t.Run("returns errors for missing tables", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
|
||||
return []SchemaError{
|
||||
{Table: "jobs"},
|
||||
{Table: "proxy_access_tokens"},
|
||||
}
|
||||
},
|
||||
}
|
||||
errs := ms.CheckSchema(RequiredSchema)
|
||||
require.Len(t, errs, 2)
|
||||
assert.Equal(t, "jobs", errs[0].Table)
|
||||
assert.Equal(t, "", errs[0].Column)
|
||||
assert.Equal(t, "proxy_access_tokens", errs[1].Table)
|
||||
})
|
||||
|
||||
t.Run("returns errors for missing columns", func(t *testing.T) {
|
||||
ms := &testStore{
|
||||
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
|
||||
return []SchemaError{
|
||||
{Table: "users", Column: "email"},
|
||||
{Table: "users", Column: "name"},
|
||||
}
|
||||
},
|
||||
}
|
||||
errs := ms.CheckSchema(RequiredSchema)
|
||||
require.Len(t, errs, 2)
|
||||
assert.Equal(t, "users", errs[0].Table)
|
||||
assert.Equal(t, "email", errs[0].Column)
|
||||
})
|
||||
}
|
||||
82
management/server/idp/migration/store.go
Normal file
82
management/server/idp/migration/store.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// SchemaCheck represents a table and the columns required on it.
|
||||
type SchemaCheck struct {
|
||||
Table string
|
||||
Columns []string
|
||||
}
|
||||
|
||||
// RequiredSchema lists all tables and columns that the migration tool needs.
|
||||
// If any are missing, the user must upgrade their management server first so
|
||||
// that the automatic GORM migrations create them.
|
||||
var RequiredSchema = []SchemaCheck{
|
||||
{Table: "users", Columns: []string{"id", "email", "name", "account_id"}},
|
||||
{Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}},
|
||||
{Table: "peers", Columns: []string{"user_id"}},
|
||||
{Table: "accounts", Columns: []string{"created_by"}},
|
||||
{Table: "user_invites", Columns: []string{"created_by"}},
|
||||
{Table: "proxy_access_tokens", Columns: []string{"created_by"}},
|
||||
{Table: "jobs", Columns: []string{"triggered_by"}},
|
||||
}
|
||||
|
||||
// SchemaError describes a single missing table or column.
|
||||
type SchemaError struct {
|
||||
Table string
|
||||
Column string // empty when the whole table is missing
|
||||
}
|
||||
|
||||
func (e SchemaError) String() string {
|
||||
if e.Column == "" {
|
||||
return fmt.Sprintf("table %q is missing", e.Table)
|
||||
}
|
||||
return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table)
|
||||
}
|
||||
|
||||
// Store defines the data store operations required for IdP user migration.
|
||||
// This interface is separate from the main store.Store interface because these methods
|
||||
// are only used during one-time migration and should be removed once migration tooling
|
||||
// is no longer needed.
|
||||
//
|
||||
// The SQL store implementations (SqlStore) already have these methods on their concrete
|
||||
// types, so they satisfy this interface via Go's structural typing with zero code changes.
|
||||
type Store interface {
|
||||
// ListUsers returns all users across all accounts.
|
||||
ListUsers(ctx context.Context) ([]*types.User, error)
|
||||
|
||||
// UpdateUserID atomically updates a user's ID and all foreign key references
|
||||
// across the database (peers, groups, policies, PATs, etc.).
|
||||
UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error
|
||||
|
||||
// UpdateUserInfo updates a user's email and name in the store.
|
||||
UpdateUserInfo(ctx context.Context, userID, email, name string) error
|
||||
|
||||
// CheckSchema verifies that all tables and columns required by the migration
|
||||
// exist in the database. Returns a list of problems; an empty slice means OK.
|
||||
CheckSchema(checks []SchemaCheck) []SchemaError
|
||||
}
|
||||
|
||||
// RequiredEventSchema lists all tables and columns that the migration tool needs
|
||||
// in the activity/event store.
|
||||
var RequiredEventSchema = []SchemaCheck{
|
||||
{Table: "events", Columns: []string{"initiator_id", "target_id"}},
|
||||
{Table: "deleted_users", Columns: []string{"id"}},
|
||||
}
|
||||
|
||||
// EventStore defines the activity event store operations required for migration.
|
||||
// Like Store, this is a temporary interface for migration tooling only.
|
||||
type EventStore interface {
|
||||
// CheckSchema verifies that all tables and columns required by the migration
|
||||
// exist in the event database. Returns a list of problems; an empty slice means OK.
|
||||
CheckSchema(checks []SchemaCheck) []SchemaError
|
||||
|
||||
// UpdateUserID updates all event references (initiator_id, target_id) and
|
||||
// deleted_users records to use the new user ID format.
|
||||
UpdateUserID(ctx context.Context, oldUserID, newUserID string) error
|
||||
}
|
||||
177
management/server/store/sql_store_idp_migration.go
Normal file
177
management/server/store/sql_store_idp_migration.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package store
|
||||
|
||||
// This file contains migration-only methods on SqlStore.
|
||||
// They satisfy the migration.Store interface via duck typing.
|
||||
// Delete this file when migration tooling is no longer needed.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp/migration"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func (s *SqlStore) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
|
||||
migrator := s.db.Migrator()
|
||||
var errs []migration.SchemaError
|
||||
|
||||
for _, check := range checks {
|
||||
if !migrator.HasTable(check.Table) {
|
||||
errs = append(errs, migration.SchemaError{Table: check.Table})
|
||||
continue
|
||||
}
|
||||
for _, col := range check.Columns {
|
||||
if !migrator.HasColumn(check.Table, col) {
|
||||
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func (s *SqlStore) ListUsers(ctx context.Context) ([]*types.User, error) {
|
||||
tx := s.db
|
||||
var users []*types.User
|
||||
result := tx.Find(&users)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when listing users from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue listing users from store")
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// txDeferFKConstraints defers foreign key constraint checks for the duration of the transaction.
|
||||
// MySQL is already handled by s.transaction (SET FOREIGN_KEY_CHECKS = 0).
|
||||
func (s *SqlStore) txDeferFKConstraints(tx *gorm.DB) error {
|
||||
if s.storeEngine == types.SqliteStoreEngine {
|
||||
return tx.Exec("PRAGMA defer_foreign_keys = ON").Error
|
||||
}
|
||||
|
||||
if s.storeEngine != types.PostgresStoreEngine {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GORM creates FK constraints as NOT DEFERRABLE by default, so
|
||||
// SET CONSTRAINTS ALL DEFERRED is a no-op unless we ALTER them first.
|
||||
err := tx.Exec(`
|
||||
DO $$ DECLARE r RECORD;
|
||||
BEGIN
|
||||
FOR r IN SELECT conname, conrelid::regclass AS tbl
|
||||
FROM pg_constraint WHERE contype = 'f' AND NOT condeferrable
|
||||
LOOP
|
||||
EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I DEFERRABLE INITIALLY IMMEDIATE', r.tbl, r.conname);
|
||||
END LOOP;
|
||||
END $$
|
||||
`).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("make FK constraints deferrable: %w", err)
|
||||
}
|
||||
return tx.Exec("SET CONSTRAINTS ALL DEFERRED").Error
|
||||
}
|
||||
|
||||
// txRestoreFKConstraints reverts FK constraints back to NOT DEFERRABLE after the
|
||||
// deferred updates are done but before the transaction commits.
|
||||
func (s *SqlStore) txRestoreFKConstraints(tx *gorm.DB) error {
|
||||
if s.storeEngine != types.PostgresStoreEngine {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tx.Exec(`
|
||||
DO $$ DECLARE r RECORD;
|
||||
BEGIN
|
||||
FOR r IN SELECT conname, conrelid::regclass AS tbl
|
||||
FROM pg_constraint WHERE contype = 'f' AND condeferrable
|
||||
LOOP
|
||||
EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I NOT DEFERRABLE', r.tbl, r.conname);
|
||||
END LOOP;
|
||||
END $$
|
||||
`).Error
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
|
||||
user := &types.User{Email: email, Name: name}
|
||||
if err := user.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user info: %w", err)
|
||||
}
|
||||
|
||||
result := s.db.Model(&types.User{}).Where("id = ?", userID).Updates(map[string]any{
|
||||
"email": user.Email,
|
||||
"name": user.Name,
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error updating user info for %s: %s", userID, result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update user info")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
|
||||
type fkUpdate struct {
|
||||
model any
|
||||
column string
|
||||
where string
|
||||
}
|
||||
|
||||
updates := []fkUpdate{
|
||||
{&types.PersonalAccessToken{}, "user_id", "user_id = ?"},
|
||||
{&types.PersonalAccessToken{}, "created_by", "created_by = ?"},
|
||||
{&nbpeer.Peer{}, "user_id", "user_id = ?"},
|
||||
{&types.UserInviteRecord{}, "created_by", "created_by = ?"},
|
||||
{&types.Account{}, "created_by", "created_by = ?"},
|
||||
{&types.ProxyAccessToken{}, "created_by", "created_by = ?"},
|
||||
{&types.Job{}, "triggered_by", "triggered_by = ?"},
|
||||
}
|
||||
|
||||
log.Info("Updating user ID in the store")
|
||||
err := s.transaction(func(tx *gorm.DB) error {
|
||||
if err := s.txDeferFKConstraints(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, u := range updates {
|
||||
if err := tx.Model(u.model).Where(u.where, oldUserID).Update(u.column, newUserID).Error; err != nil {
|
||||
return fmt.Errorf("update %s: %w", u.column, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.User{}).Where(accountAndIDQueryCondition, accountID, oldUserID).Update("id", newUserID).Error; err != nil {
|
||||
return fmt.Errorf("update users: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update user ID in the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to update user ID in store")
|
||||
}
|
||||
|
||||
log.Info("Restoring FK constraints")
|
||||
err = s.transaction(func(tx *gorm.DB) error {
|
||||
if err := s.txRestoreFKConstraints(tx); err != nil {
|
||||
return fmt.Errorf("restore FK constraints: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to restore FK constraints after user ID update: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to restore FK constraints after user ID update")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user