Merge branch 'main' into crowdsec-integration

# Conflicts:
#	shared/management/proto/proxy_service.pb.go
This commit is contained in:
Viktor Liu
2026-04-01 19:59:14 +02:00
70 changed files with 6354 additions and 565 deletions

View File

@@ -0,0 +1,61 @@
package store
// This file contains migration-only methods on Store.
// They satisfy the migration.MigrationEventStore interface via duck typing.
// Delete this file when migration tooling is no longer needed.
import (
"context"
"fmt"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp/migration"
)
// CheckSchema verifies that all tables and columns required by the migration exist in the event database.
func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
migrator := store.db.Migrator()
var errs []migration.SchemaError
for _, check := range checks {
if !migrator.HasTable(check.Table) {
errs = append(errs, migration.SchemaError{Table: check.Table})
continue
}
for _, col := range check.Columns {
if !migrator.HasColumn(check.Table, col) {
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
}
}
}
return errs
}
// UpdateUserID updates all references to oldUserID in events and deleted_users tables.
func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error {
return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&activity.Event{}).
Where("initiator_id = ?", oldUserID).
Update("initiator_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.initiator_id: %w", err)
}
if err := tx.Model(&activity.Event{}).
Where("target_id = ?", oldUserID).
Update("target_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.target_id: %w", err)
}
// Raw exec: GORM can't update a PK via Model().Update()
if err := tx.Exec(
"UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID,
).Error; err != nil {
return fmt.Errorf("update deleted_users.id: %w", err)
}
return nil
})
}

View File

@@ -0,0 +1,161 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUpdateUserID(t *testing.T) {
ctx := context.Background()
newStore := func(t *testing.T) *Store {
t.Helper()
key, _ := crypt.GenerateKey()
s, err := NewSqlStore(ctx, t.TempDir(), key)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { s.Close(ctx) }) //nolint
return s
}
t.Run("updates initiator_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "old-user",
TargetID: "some-peer",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].InitiatorID)
})
t.Run("updates target_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "some-admin",
TargetID: "old-user",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].TargetID)
})
t.Run("updates deleted_users id", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
// Save an event with email/name meta to create a deleted_users row for "old-user"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "old-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
// Save another event referencing new-user with email/name meta.
// This should upsert (not conflict) because the PK was already migrated.
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "new-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
// The deleted user info should be retrievable via Get (joined on target_id)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
assert.Equal(t, "new-user", ev.TargetID)
}
})
t.Run("no-op when old user ID does not exist", func(t *testing.T) {
store := newStore(t)
err := store.UpdateUserID(ctx, "nonexistent-user", "new-user")
assert.NoError(t, err)
})
t.Run("only updates matching user leaves others unchanged", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-a",
TargetID: "peer-1",
AccountID: accountID,
})
assert.NoError(t, err)
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-b",
TargetID: "peer-2",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "user-a", "user-a-new")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
if ev.TargetID == "peer-1" {
assert.Equal(t, "user-a-new", ev.InitiatorID)
} else {
assert.Equal(t, "user-b", ev.InitiatorID)
}
}
})
}

View File

@@ -33,15 +33,20 @@ type manager struct {
extractor *nbjwt.ClaimsExtractor
}
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator := nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
var jwtValidator *nbjwt.Validator
if keyFetcher != nil {
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
} else {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator = nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
}
claimsExtractor := nbjwt.NewClaimsExtractor(
nbjwt.WithAudience(audience),

View File

@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
if err != nil {
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
err = manager.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
// these tests only assert groups are parsed from token as per account settings
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
t.Run("JWT groups disabled", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
keyId := "test-key"
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
customClaim := func(name string) string {
return fmt.Sprintf("%s/%s", audience, name)

View File

@@ -130,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock()
defer gl.mux.RUnlock()
if gl.db == nil {
return nil, fmt.Errorf("geolocation database is not available")
}
var record Record
err := gl.db.Lookup(ip, &record)
if err != nil {
@@ -173,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er
func (gl *geolocationImpl) Stop() error {
close(gl.stopCh)
if gl.db != nil {
if err := gl.db.Close(); err != nil {
gl.mux.Lock()
db := gl.db
gl.db = nil
gl.mux.Unlock()
if db != nil {
if err := db.Close(); err != nil {
return err
}
}

View File

@@ -119,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
@@ -248,7 +248,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/telemetry"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -48,6 +49,8 @@ type EmbeddedIdPConfig struct {
// Existing local users are preserved and will be able to login again if re-enabled.
// Cannot be enabled if no external identity provider connectors are configured.
LocalAuthDisabled bool
// StaticConnectors are additional connectors to seed during initialization
StaticConnectors []dex.Connector
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
@@ -157,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
RedirectURIs: cliRedirectURIs,
},
},
StaticConnectors: c.StaticConnectors,
}
// Add owner user if provided
@@ -193,6 +197,9 @@ type OAuthConfigProvider interface {
// Management server has embedded Dex and can validate tokens via localhost,
// avoiding external network calls and DNS resolution issues during startup.
GetLocalKeysLocation() string
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
// or nil if direct key fetching is not supported (falls back to HTTP).
GetKeyFetcher() nbjwt.KeyFetcher
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
@@ -593,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
return m.config.CLIRedirectURIs
}
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
return m.provider.GetJWKS
}
// GetKeysLocation returns the JWKS endpoint URL for token validation.
func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()

View File

@@ -0,0 +1,235 @@
// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0
// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later.
// It includes functions to seed connectors and migrate existing users to use these connectors.
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// Server is the dependency interface that migration functions use to access
// the main data store and the activity event store.
type Server interface {
Store() Store
EventStore() EventStore // may return nil
}
const idpSeedInfoKey = "IDP_SEED_INFO"
const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN"
func isDryRun() bool {
return os.Getenv(dryRunEnvKey) == "true"
}
var ErrNoSeedInfo = errors.New("no seed info found in environment")
// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it,
// and JSON-unmarshals it into a dex.Connector. Returns nil if not set.
func SeedConnectorFromEnv() (*dex.Connector, error) {
val, ok := os.LookupEnv(idpSeedInfoKey)
if !ok || val == "" {
return nil, ErrNoSeedInfo
}
decoded, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
var conn dex.Connector
if err := json.Unmarshal(decoded, &conn); err != nil {
return nil, fmt.Errorf("json unmarshal: %w", err)
}
return &conn, nil
}
// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and
// the activity store, if present) so that it encodes the given connector ID,
// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true
// to log what would happen without writing any changes.
func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error {
ctx := context.Background()
if isDryRun() {
log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written")
}
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Reconciliation pass: fix activity store for users already migrated in main DB
// but whose activity references may still use old IDs (from a previous partial failure).
if s.EventStore() != nil && !isDryRun() {
if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil {
return err
}
}
var migratedCount, skippedCount int
for _, user := range users {
_, _, decErr := dex.DecodeDexUserID(user.Id)
if decErr == nil {
skippedCount++
continue
}
newUserID := dex.EncodeDexUserID(user.Id, conn.ID)
if isDryRun() {
log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID)
migratedCount++
continue
}
if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil {
return err
}
migratedCount++
}
if isDryRun() {
log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount)
} else {
log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount)
}
return nil
}
// reconcileActivityStore updates activity store references for users already migrated
// in the main DB whose activity entries may still use old IDs from a previous partial failure.
func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error {
for _, user := range users {
originalID, _, err := dex.DecodeDexUserID(user.Id)
if err != nil {
// skip users that aren't migrated, they will be handled in the main migration loop
continue
}
if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil {
return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err)
}
}
return nil
}
// migrateUser updates a single user's ID in both the main store and the activity store.
func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error {
if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil {
return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err)
}
if s.EventStore() == nil {
return nil
}
if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil {
return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err)
}
return nil
}
// PopulateUserInfo fetches user email and name from the external IDP and updates
// the store for users that are missing this information.
func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error {
ctx := context.Background()
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Build a map of IDP user ID -> UserData from the external IDP
allAccounts, err := idpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to fetch accounts from IDP: %w", err)
}
idpUsers := make(map[string]*idp.UserData)
for _, accountUsers := range allAccounts {
for _, userData := range accountUsers {
idpUsers[userData.ID] = userData
}
}
log.Infof("fetched %d users from IDP", len(idpUsers))
var updatedCount, skippedCount, notFoundCount int
for _, user := range users {
if user.IsServiceUser {
skippedCount++
continue
}
if user.Email != "" && user.Name != "" {
skippedCount++
continue
}
// The user ID in the store may be the original IDP ID or a Dex-encoded ID.
// Try to decode the Dex format first to get the original IDP ID.
lookupID := user.Id
if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil {
lookupID = originalID
}
idpUser, found := idpUsers[lookupID]
if !found {
notFoundCount++
log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID)
continue
}
email := user.Email
name := user.Name
if email == "" && idpUser.Email != "" {
email = idpUser.Email
}
if name == "" && idpUser.Name != "" {
name = idpUser.Name
}
if email == user.Email && name == user.Name {
skippedCount++
continue
}
if dryRun {
log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
continue
}
if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil {
return fmt.Errorf("failed to update user info for %s: %w", user.Id, err)
}
log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
}
if dryRun {
log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
} else {
log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
}
return nil
}

View File

@@ -0,0 +1,828 @@
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// testStore is a hand-written mock for MigrationStore.
type testStore struct {
listUsersFunc func(ctx context.Context) ([]*types.User, error)
updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error
updateUserInfoFunc func(ctx context.Context, userID, email, name string) error
checkSchemaFunc func(checks []SchemaCheck) []SchemaError
updateCalls []updateUserIDCall
updateInfoCalls []updateUserInfoCall
}
type updateUserIDCall struct {
AccountID string
OldUserID string
NewUserID string
}
type updateUserInfoCall struct {
UserID string
Email string
Name string
}
func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) {
return s.listUsersFunc(ctx)
}
func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID})
return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID)
}
func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name})
if s.updateUserInfoFunc != nil {
return s.updateUserInfoFunc(ctx, userID, email, name)
}
return nil
}
func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError {
if s.checkSchemaFunc != nil {
return s.checkSchemaFunc(checks)
}
return nil
}
type testServer struct {
store Store
eventStore EventStore
}
func (s *testServer) Store() Store { return s.store }
func (s *testServer) EventStore() EventStore { return s.eventStore }
func TestSeedConnectorFromEnv(t *testing.T) {
t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) {
os.Unsetenv(idpSeedInfoKey)
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "")
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns error on invalid base64", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "not-valid-base64!!!")
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "base64 decode")
})
t.Run("returns error on invalid JSON", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("not json"))
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "json unmarshal")
})
t.Run("successfully decodes valid connector", func(t *testing.T) {
expected := dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-provider",
Config: map[string]any{
"issuer": "https://example.com",
"clientID": "my-client-id",
"clientSecret": "my-secret",
},
}
data, err := json.Marshal(expected)
require.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.Equal(t, expected.Type, conn.Type)
assert.Equal(t, expected.Name, conn.Name)
assert.Equal(t, expected.ID, conn.ID)
assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"])
})
}
func TestMigrateUsersToStaticConnectors(t *testing.T) {
connector := &dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-connector",
}
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("migrates single user with correct encoded ID", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
expectedNewID := dex.EncodeDexUserID("user-1", "test-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
require.Len(t, ms.updateCalls, 1)
assert.Equal(t, "account-1", ms.updateCalls[0].AccountID)
assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID)
assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID)
})
t.Run("migrates multiple users", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 3)
})
t.Run("returns error when UpdateUserID fails", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
callCount++
if callCount == 2 {
return fmt.Errorf("update failed")
}
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user ID for user user-2")
})
t.Run("stops on first UpdateUserID error", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return fmt.Errorf("update failed")
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Len(t, ms.updateCalls, 1) // stopped after first error
})
t.Run("skips already migrated users", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
// Only user-2 and user-3 should be migrated
assert.Len(t, ms.updateCalls, 2)
assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID)
assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID)
})
t.Run("dry run does not call UpdateUserID", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("dry run skips already migrated users", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("dry run disabled by default", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run
})
}
func TestPopulateUserInfo(t *testing.T) {
noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("returns error when GetAllAccounts fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return nil, fmt.Errorf("idp error")
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch accounts from IDP")
})
t.Run("updates user with missing email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "user1@example.com", Name: "User One"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User One", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing email when name exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing name when email exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name)
})
t.Run("skips users that already have both email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips service users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips users not found in IDP", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) {
dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User", ms.updateInfoCalls[0].Name)
})
t.Run("handles multiple users across multiple accounts", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"},
{Id: "user-3", AccountID: "acc-2", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "User 1"},
{ID: "user-2", Email: "u2@example.com", Name: "User 2"},
},
"acc-2": {
{ID: "user-3", Email: "u3@example.com", Name: "User 3"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 2)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID)
assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email)
})
t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user info for user-1")
})
t.Run("stops on first UpdateUserInfo error", func(t *testing.T) {
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
callCount++
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Equal(t, 1, callCount)
})
t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
t.Fatal("UpdateUserInfo should not be called in dry-run mode")
return nil
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, true)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips user when IDP has empty email and name too", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "", Name: ""}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
}
func TestSchemaError_String(t *testing.T) {
t.Run("missing table", func(t *testing.T) {
e := SchemaError{Table: "jobs"}
assert.Equal(t, `table "jobs" is missing`, e.String())
})
t.Run("missing column", func(t *testing.T) {
e := SchemaError{Table: "users", Column: "email"}
assert.Equal(t, `column "email" on table "users" is missing`, e.String())
})
}
func TestRequiredSchema(t *testing.T) {
// Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo.
expectedTables := []string{
"users",
"personal_access_tokens",
"peers",
"accounts",
"user_invites",
"proxy_access_tokens",
"jobs",
}
schemaTableNames := make([]string, len(RequiredSchema))
for i, s := range RequiredSchema {
schemaTableNames[i] = s.Table
}
for _, expected := range expectedTables {
assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected)
}
}
func TestCheckSchema_MockStore(t *testing.T) {
t.Run("returns nil when all schema exists", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return nil
},
}
errs := ms.CheckSchema(RequiredSchema)
assert.Empty(t, errs)
})
t.Run("returns errors for missing tables", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "jobs"},
{Table: "proxy_access_tokens"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "jobs", errs[0].Table)
assert.Equal(t, "", errs[0].Column)
assert.Equal(t, "proxy_access_tokens", errs[1].Table)
})
t.Run("returns errors for missing columns", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "users", Column: "email"},
{Table: "users", Column: "name"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "users", errs[0].Table)
assert.Equal(t, "email", errs[0].Column)
})
}

View File

@@ -0,0 +1,82 @@
package migration
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/types"
)
// SchemaCheck represents a table and the columns required on it.
type SchemaCheck struct {
Table string
Columns []string
}
// RequiredSchema lists all tables and columns that the migration tool needs.
// If any are missing, the user must upgrade their management server first so
// that the automatic GORM migrations create them.
var RequiredSchema = []SchemaCheck{
{Table: "users", Columns: []string{"id", "email", "name", "account_id"}},
{Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}},
{Table: "peers", Columns: []string{"user_id"}},
{Table: "accounts", Columns: []string{"created_by"}},
{Table: "user_invites", Columns: []string{"created_by"}},
{Table: "proxy_access_tokens", Columns: []string{"created_by"}},
{Table: "jobs", Columns: []string{"triggered_by"}},
}
// SchemaError describes a single missing table or column.
type SchemaError struct {
Table string
Column string // empty when the whole table is missing
}
func (e SchemaError) String() string {
if e.Column == "" {
return fmt.Sprintf("table %q is missing", e.Table)
}
return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table)
}
// Store defines the data store operations required for IdP user migration.
// This interface is separate from the main store.Store interface because these methods
// are only used during one-time migration and should be removed once migration tooling
// is no longer needed.
//
// The SQL store implementations (SqlStore) already have these methods on their concrete
// types, so they satisfy this interface via Go's structural typing with zero code changes.
type Store interface {
// ListUsers returns all users across all accounts.
ListUsers(ctx context.Context) ([]*types.User, error)
// UpdateUserID atomically updates a user's ID and all foreign key references
// across the database (peers, groups, policies, PATs, etc.).
UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error
// UpdateUserInfo updates a user's email and name in the store.
UpdateUserInfo(ctx context.Context, userID, email, name string) error
// CheckSchema verifies that all tables and columns required by the migration
// exist in the database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
}
// RequiredEventSchema lists all tables and columns that the migration tool needs
// in the activity/event store.
var RequiredEventSchema = []SchemaCheck{
{Table: "events", Columns: []string{"initiator_id", "target_id"}},
{Table: "deleted_users", Columns: []string{"id"}},
}
// EventStore defines the activity event store operations required for migration.
// Like Store, this is a temporary interface for migration tooling only.
type EventStore interface {
// CheckSchema verifies that all tables and columns required by the migration
// exist in the event database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
// UpdateUserID updates all event references (initiator_id, target_id) and
// deleted_users records to use the new user ID format.
UpdateUserID(ctx context.Context, oldUserID, newUserID string) error
}

View File

@@ -64,10 +64,19 @@ type Manager interface {
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
type instanceStore interface {
GetAccountsCounter(ctx context.Context) (int64, error)
}
type embeddedIdP interface {
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
}
// DefaultManager is the default implementation of Manager.
type DefaultManager struct {
store store.Store
embeddedIdpManager *idp.EmbeddedIdPManager
store instanceStore
embeddedIdpManager embeddedIdP
setupRequired bool
setupMu sync.RWMutex
@@ -82,18 +91,18 @@ type DefaultManager struct {
// NewManager creates a new instance manager.
// If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults.
func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) {
embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager)
embeddedIdp, ok := idpManager.(*idp.EmbeddedIdPManager)
m := &DefaultManager{
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
store: store,
setupRequired: false,
httpClient: &http.Client{
Timeout: httpTimeout,
},
}
if embeddedIdp != nil {
if ok && embeddedIdp != nil {
m.embeddedIdpManager = embeddedIdp
err := m.loadSetupRequired(ctx)
if err != nil {
return nil, err
@@ -143,36 +152,61 @@ func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) {
// CreateOwnerUser creates the initial owner user in the embedded IDP.
func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
m.setupMu.RLock()
setupRequired := m.setupRequired
m.setupMu.RUnlock()
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if !setupRequired {
m.setupMu.Lock()
defer m.setupMu.Unlock()
if !m.setupRequired {
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
}
if err := m.checkSetupRequiredFromDB(ctx); err != nil {
var sErr *status.Error
if errors.As(err, &sErr) && sErr.Type() == status.PreconditionFailed {
m.setupRequired = false
}
return nil, err
}
userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
if err != nil {
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
}
m.setupMu.Lock()
m.setupRequired = false
m.setupMu.Unlock()
log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email)
return userData, nil
}
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return fmt.Errorf("failed to check accounts: %w", err)
}
if numAccounts > 0 {
return status.Errorf(status.PreconditionFailed, "setup already completed")
}
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to check IdP users: %w", err)
}
if len(users) > 0 {
return status.Errorf(status.PreconditionFailed, "setup already completed")
}
return nil
}
func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if email == "" {
return status.Errorf(status.InvalidArgument, "email is required")
@@ -189,6 +223,9 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if len(password) < 8 {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
}
if len(password) > 72 {
return status.Errorf(status.InvalidArgument, "password must be at most 72 characters")
}
return nil
}

View File

@@ -3,7 +3,12 @@ package instance
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -11,173 +16,215 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
)
// mockStore implements a minimal store.Store for testing
type mockIdP struct {
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
users map[string][]*idp.UserData
getAllAccountsErr error
}
func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
}
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
}
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.getAllAccountsErr != nil {
return nil, m.getAllAccountsErr
}
return m.users, nil
}
type mockStore struct {
accountsCount int64
err error
}
func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) {
func (m *mockStore) GetAccountsCounter(_ context.Context) (int64, error) {
if m.err != nil {
return 0, m.err
}
return m.accountsCount, nil
}
// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing
type mockEmbeddedIdPManager struct {
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
func newTestManager(idpMock *mockIdP, storeMock *mockStore) *DefaultManager {
return &DefaultManager{
store: storeMock,
embeddedIdpManager: idpMock,
setupRequired: true,
httpClient: &http.Client{Timeout: httpTimeout},
}
return &idp.UserData{
ID: "test-user-id",
Email: email,
Name: name,
}, nil
}
// testManager is a test implementation that accepts our mock types
type testManager struct {
store *mockStore
embeddedIdpManager *mockEmbeddedIdPManager
}
func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) {
if m.embeddedIdpManager == nil {
return false, nil
}
count, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return false, err
}
return count == 0, nil
}
func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil, // No embedded IDP
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when embedded IDP is disabled")
}
func TestIsSetupRequired_NoAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts exist")
}
func TestIsSetupRequired_AccountsExist(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 1},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when accounts exist")
}
func TestIsSetupRequired_MultipleAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 5},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when multiple accounts exist")
}
func TestIsSetupRequired_StoreError(t *testing.T) {
manager := &testManager{
store: &mockStore{err: errors.New("database error")},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
_, err := manager.IsSetupRequired(context.Background())
assert.Error(t, err, "should return error when store fails")
}
func TestCreateOwnerUser_Success(t *testing.T) {
expectedEmail := "admin@example.com"
expectedName := "Admin User"
expectedPassword := "securepassword123"
idpMock := &mockIdP{}
mgr := newTestManager(idpMock, &mockStore{})
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
assert.Equal(t, expectedEmail, email)
assert.Equal(t, expectedPassword, password)
assert.Equal(t, expectedName, name)
return &idp.UserData{
ID: "created-user-id",
Email: email,
Name: name,
}, nil
},
},
}
userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName)
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "created-user-id", userData.ID)
assert.Equal(t, expectedEmail, userData.Email)
assert.Equal(t, expectedName, userData.Name)
assert.Equal(t, "admin@example.com", userData.Email)
_, err = mgr.CreateOwnerUser(context.Background(), "admin2@example.com", "password123", "Admin2")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{})
mgr.setupRequired = false
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil,
}
mgr := &DefaultManager{setupRequired: true}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when embedded IDP is disabled")
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "embedded IDP is not enabled")
}
func TestCreateOwnerUser_IdPError(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, errors.New("user already exists")
},
idpMock := &mockIdP{
createUserFunc: func(_ context.Context, _, _, _ string) (*idp.UserData, error) {
return nil, errors.New("provider error")
},
}
mgr := newTestManager(idpMock, &mockStore{})
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when IDP fails")
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "provider error")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after IdP error")
}
func TestCreateOwnerUser_TransientDBError_DoesNotBlockSetup(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{err: errors.New("connection refused")})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "connection refused")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after transient DB error")
mgr.store = &mockStore{}
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "admin@example.com", userData.Email)
}
func TestCreateOwnerUser_TransientIdPError_DoesNotBlockSetup(t *testing.T) {
idpMock := &mockIdP{getAllAccountsErr: errors.New("connection reset")}
mgr := newTestManager(idpMock, &mockStore{})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "connection reset")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after transient IdP error")
idpMock.getAllAccountsErr = nil
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "admin@example.com", userData.Email)
}
func TestCreateOwnerUser_DBCheckBlocksConcurrent(t *testing.T) {
idpMock := &mockIdP{
users: map[string][]*idp.UserData{
"acc1": {{ID: "existing-user"}},
},
}
mgr := newTestManager(idpMock, &mockStore{})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_DBCheckBlocksWhenAccountsExist(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{accountsCount: 1})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_ConcurrentRequests(t *testing.T) {
var idpCallCount atomic.Int32
var successCount atomic.Int32
var failCount atomic.Int32
idpMock := &mockIdP{
createUserFunc: func(_ context.Context, email, _, _ string) (*idp.UserData, error) {
idpCallCount.Add(1)
time.Sleep(50 * time.Millisecond)
return &idp.UserData{ID: "user-1", Email: email, Name: "Owner"}, nil
},
}
mgr := newTestManager(idpMock, &mockStore{})
var wg sync.WaitGroup
for i := range 10 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, err := mgr.CreateOwnerUser(
context.Background(),
fmt.Sprintf("owner%d@example.com", idx),
"password1234",
fmt.Sprintf("Owner%d", idx),
)
if err != nil {
failCount.Add(1)
} else {
successCount.Add(1)
}
}(i)
}
wg.Wait()
assert.Equal(t, int32(1), successCount.Load(), "exactly one concurrent setup request should succeed")
assert.Equal(t, int32(9), failCount.Load(), "remaining concurrent requests should fail")
assert.Equal(t, int32(1), idpCallCount.Load(), "IdP CreateUser should be called exactly once")
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
mgr := &DefaultManager{}
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required)
}
func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{})
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required)
mgr.setupMu.Lock()
mgr.setupRequired = false
mgr.setupMu.Unlock()
required, err = mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required)
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{
setupRequired: true,
}
manager := &DefaultManager{setupRequired: true}
tests := []struct {
name string
@@ -188,11 +235,10 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
errorMsg string
}{
{
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
expectError: false,
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
},
{
name: "empty email",
@@ -235,11 +281,24 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
errorMsg: "password must be at least 8 characters",
},
{
name: "password exactly 8 characters",
name: "password exactly 8 characters",
email: "admin@example.com",
password: "12345678",
userName: "Admin User",
},
{
name: "password exactly 72 characters",
email: "admin@example.com",
password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiii",
userName: "Admin User",
},
{
name: "password too long",
email: "admin@example.com",
password: "12345678",
password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiiij",
userName: "Admin User",
expectError: false,
expectError: true,
errorMsg: "password must be at most 72 characters",
},
}
@@ -255,14 +314,3 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
})
}
}
func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
manager := &DefaultManager{
setupRequired: false,
embeddedIdpManager: &idp.EmbeddedIdPManager{},
}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}

View File

@@ -84,7 +84,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

View File

@@ -0,0 +1,177 @@
package store
// This file contains migration-only methods on SqlStore.
// They satisfy the migration.Store interface via duck typing.
// Delete this file when migration tooling is no longer needed.
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/idp/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
func (s *SqlStore) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
migrator := s.db.Migrator()
var errs []migration.SchemaError
for _, check := range checks {
if !migrator.HasTable(check.Table) {
errs = append(errs, migration.SchemaError{Table: check.Table})
continue
}
for _, col := range check.Columns {
if !migrator.HasColumn(check.Table, col) {
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
}
}
}
return errs
}
func (s *SqlStore) ListUsers(ctx context.Context) ([]*types.User, error) {
tx := s.db
var users []*types.User
result := tx.Find(&users)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when listing users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue listing users from store")
}
for _, user := range users {
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
}
return users, nil
}
// txDeferFKConstraints defers foreign key constraint checks for the duration of the transaction.
// MySQL is already handled by s.transaction (SET FOREIGN_KEY_CHECKS = 0).
func (s *SqlStore) txDeferFKConstraints(tx *gorm.DB) error {
if s.storeEngine == types.SqliteStoreEngine {
return tx.Exec("PRAGMA defer_foreign_keys = ON").Error
}
if s.storeEngine != types.PostgresStoreEngine {
return nil
}
// GORM creates FK constraints as NOT DEFERRABLE by default, so
// SET CONSTRAINTS ALL DEFERRED is a no-op unless we ALTER them first.
err := tx.Exec(`
DO $$ DECLARE r RECORD;
BEGIN
FOR r IN SELECT conname, conrelid::regclass AS tbl
FROM pg_constraint WHERE contype = 'f' AND NOT condeferrable
LOOP
EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I DEFERRABLE INITIALLY IMMEDIATE', r.tbl, r.conname);
END LOOP;
END $$
`).Error
if err != nil {
return fmt.Errorf("make FK constraints deferrable: %w", err)
}
return tx.Exec("SET CONSTRAINTS ALL DEFERRED").Error
}
// txRestoreFKConstraints reverts FK constraints back to NOT DEFERRABLE after the
// deferred updates are done but before the transaction commits.
func (s *SqlStore) txRestoreFKConstraints(tx *gorm.DB) error {
if s.storeEngine != types.PostgresStoreEngine {
return nil
}
return tx.Exec(`
DO $$ DECLARE r RECORD;
BEGIN
FOR r IN SELECT conname, conrelid::regclass AS tbl
FROM pg_constraint WHERE contype = 'f' AND condeferrable
LOOP
EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I NOT DEFERRABLE', r.tbl, r.conname);
END LOOP;
END $$
`).Error
}
func (s *SqlStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
user := &types.User{Email: email, Name: name}
if err := user.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user info: %w", err)
}
result := s.db.Model(&types.User{}).Where("id = ?", userID).Updates(map[string]any{
"email": user.Email,
"name": user.Name,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("error updating user info for %s: %s", userID, result.Error)
return status.Errorf(status.Internal, "failed to update user info")
}
return nil
}
func (s *SqlStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
type fkUpdate struct {
model any
column string
where string
}
updates := []fkUpdate{
{&types.PersonalAccessToken{}, "user_id", "user_id = ?"},
{&types.PersonalAccessToken{}, "created_by", "created_by = ?"},
{&nbpeer.Peer{}, "user_id", "user_id = ?"},
{&types.UserInviteRecord{}, "created_by", "created_by = ?"},
{&types.Account{}, "created_by", "created_by = ?"},
{&types.ProxyAccessToken{}, "created_by", "created_by = ?"},
{&types.Job{}, "triggered_by", "triggered_by = ?"},
}
log.Info("Updating user ID in the store")
err := s.transaction(func(tx *gorm.DB) error {
if err := s.txDeferFKConstraints(tx); err != nil {
return err
}
for _, u := range updates {
if err := tx.Model(u.model).Where(u.where, oldUserID).Update(u.column, newUserID).Error; err != nil {
return fmt.Errorf("update %s: %w", u.column, err)
}
}
if err := tx.Model(&types.User{}).Where(accountAndIDQueryCondition, accountID, oldUserID).Update("id", newUserID).Error; err != nil {
return fmt.Errorf("update users: %w", err)
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update user ID in the store: %s", err)
return status.Errorf(status.Internal, "failed to update user ID in store")
}
log.Info("Restoring FK constraints")
err = s.transaction(func(tx *gorm.DB) error {
if err := s.txRestoreFKConstraints(tx); err != nil {
return fmt.Errorf("restore FK constraints: %w", err)
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to restore FK constraints after user ID update: %s", err)
return status.Errorf(status.Internal, "failed to restore FK constraints after user ID update")
}
return nil
}

View File

@@ -84,6 +84,12 @@ func setupTestAccount() *Account {
},
},
Groups: map[string]*Group{
"groupAll": {
ID: "groupAll",
Name: "All",
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
Issued: GroupIssuedAPI,
},
"group1": {
ID: "group1",
Peers: []string{"peer11", "peer12"},

View File

@@ -417,6 +417,10 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, err
}
if targetUser.AccountID != accountID {
return nil, status.NewPermissionDeniedError()
}
// @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
@@ -457,6 +461,10 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return err
}
if targetUser.AccountID != accountID {
return status.NewPermissionDeniedError()
}
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return status.NewAdminPermissionError()
}
@@ -496,6 +504,10 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, err
}
if targetUser.AccountID != accountID {
return nil, status.NewPermissionDeniedError()
}
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}
@@ -523,6 +535,10 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, err
}
if targetUser.AccountID != accountID {
return nil, status.NewPermissionDeniedError()
}
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}

View File

@@ -336,6 +336,104 @@ func TestUser_GetAllPATs(t *testing.T) {
assert.Equal(t, 2, len(pats))
}
func TestUser_PAT_CrossAccountProtection(t *testing.T) {
const (
accountAID = "accountA"
accountBID = "accountB"
userAID = "userA"
adminBID = "adminB"
serviceUserBID = "serviceUserB"
regularUserBID = "regularUserB"
tokenBID = "tokenB1"
hashedTokenB = "SoMeHaShEdToKeNB"
)
setupStore := func(t *testing.T) (*DefaultAccountManager, func()) {
t.Helper()
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err, "creating store")
accountA := newAccountWithId(context.Background(), accountAID, userAID, "", "", "", false)
require.NoError(t, s.SaveAccount(context.Background(), accountA))
accountB := newAccountWithId(context.Background(), accountBID, adminBID, "", "", "", false)
accountB.Users[serviceUserBID] = &types.User{
Id: serviceUserBID,
AccountID: accountBID,
IsServiceUser: true,
ServiceUserName: "svcB",
Role: types.UserRoleAdmin,
PATs: map[string]*types.PersonalAccessToken{
tokenBID: {
ID: tokenBID,
HashedToken: hashedTokenB,
},
},
}
accountB.Users[regularUserBID] = &types.User{
Id: regularUserBID,
AccountID: accountBID,
Role: types.UserRoleUser,
}
require.NoError(t, s.SaveAccount(context.Background(), accountB))
pm := permissions.NewManager(s)
am := &DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: pm,
}
return am, cleanup
}
t.Run("CreatePAT for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.CreatePAT(context.Background(), accountAID, userAID, serviceUserBID, "xss-token", 7)
require.Error(t, err, "cross-account CreatePAT must fail")
_, err = am.CreatePAT(context.Background(), accountAID, userAID, regularUserBID, "xss-token", 7)
require.Error(t, err, "cross-account CreatePAT for regular user must fail")
_, err = am.CreatePAT(context.Background(), accountBID, adminBID, serviceUserBID, "legit-token", 7)
require.NoError(t, err, "same-account CreatePAT should succeed")
})
t.Run("DeletePAT for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
err := am.DeletePAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID)
require.Error(t, err, "cross-account DeletePAT must fail")
})
t.Run("GetPAT for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.GetPAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID)
require.Error(t, err, "cross-account GetPAT must fail")
})
t.Run("GetAllPATs for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.GetAllPATs(context.Background(), accountAID, userAID, serviceUserBID)
require.Error(t, err, "cross-account GetAllPATs must fail")
})
t.Run("CreatePAT with forged accountID targeting foreign user is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.CreatePAT(context.Background(), accountAID, userAID, adminBID, "forged", 7)
require.Error(t, err, "forged accountID CreatePAT must fail")
})
}
func TestUser_Copy(t *testing.T) {
// this is an imaginary case which will never be in DB this way
user := types.User{