Merge branch 'main' into refactor/permissions-manager

This commit is contained in:
pascal
2026-04-07 17:35:39 +02:00
112 changed files with 10718 additions and 2106 deletions

View File

@@ -3139,7 +3139,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, nil))
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, proxyController, proxyManager, nil))
return manager, updateManager, nil
}

View File

@@ -0,0 +1,61 @@
package store
// This file contains migration-only methods on Store.
// They satisfy the migration.MigrationEventStore interface via duck typing.
// Delete this file when migration tooling is no longer needed.
import (
"context"
"fmt"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp/migration"
)
// CheckSchema verifies that all tables and columns required by the migration exist in the event database.
func (store *Store) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
migrator := store.db.Migrator()
var errs []migration.SchemaError
for _, check := range checks {
if !migrator.HasTable(check.Table) {
errs = append(errs, migration.SchemaError{Table: check.Table})
continue
}
for _, col := range check.Columns {
if !migrator.HasColumn(check.Table, col) {
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
}
}
}
return errs
}
// UpdateUserID updates all references to oldUserID in events and deleted_users tables.
func (store *Store) UpdateUserID(ctx context.Context, oldUserID, newUserID string) error {
return store.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Model(&activity.Event{}).
Where("initiator_id = ?", oldUserID).
Update("initiator_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.initiator_id: %w", err)
}
if err := tx.Model(&activity.Event{}).
Where("target_id = ?", oldUserID).
Update("target_id", newUserID).Error; err != nil {
return fmt.Errorf("update events.target_id: %w", err)
}
// Raw exec: GORM can't update a PK via Model().Update()
if err := tx.Exec(
"UPDATE deleted_users SET id = ? WHERE id = ?", newUserID, oldUserID,
).Error; err != nil {
return fmt.Errorf("update deleted_users.id: %w", err)
}
return nil
})
}

View File

@@ -0,0 +1,161 @@
package store
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/util/crypt"
)
func TestUpdateUserID(t *testing.T) {
ctx := context.Background()
newStore := func(t *testing.T) *Store {
t.Helper()
key, _ := crypt.GenerateKey()
s, err := NewSqlStore(ctx, t.TempDir(), key)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { s.Close(ctx) }) //nolint
return s
}
t.Run("updates initiator_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "old-user",
TargetID: "some-peer",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].InitiatorID)
})
t.Run("updates target_id in events", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "some-admin",
TargetID: "old-user",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, "new-user", result[0].TargetID)
})
t.Run("updates deleted_users id", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
// Save an event with email/name meta to create a deleted_users row for "old-user"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "old-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "old-user", "new-user")
assert.NoError(t, err)
// Save another event referencing new-user with email/name meta.
// This should upsert (not conflict) because the PK was already migrated.
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "admin",
TargetID: "new-user",
AccountID: accountID,
Meta: map[string]any{
"email": "user@example.com",
"name": "Test User",
},
})
assert.NoError(t, err)
// The deleted user info should be retrievable via Get (joined on target_id)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
assert.Equal(t, "new-user", ev.TargetID)
}
})
t.Run("no-op when old user ID does not exist", func(t *testing.T) {
store := newStore(t)
err := store.UpdateUserID(ctx, "nonexistent-user", "new-user")
assert.NoError(t, err)
})
t.Run("only updates matching user leaves others unchanged", func(t *testing.T) {
store := newStore(t)
accountID := "account_1"
_, err := store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-a",
TargetID: "peer-1",
AccountID: accountID,
})
assert.NoError(t, err)
_, err = store.Save(ctx, &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user-b",
TargetID: "peer-2",
AccountID: accountID,
})
assert.NoError(t, err)
err = store.UpdateUserID(ctx, "user-a", "user-a-new")
assert.NoError(t, err)
result, err := store.Get(ctx, accountID, 0, 10, false)
assert.NoError(t, err)
assert.Len(t, result, 2)
for _, ev := range result {
if ev.TargetID == "peer-1" {
assert.Equal(t, "user-a-new", ev.InitiatorID)
} else {
assert.Equal(t, "user-b", ev.InitiatorID)
}
}
})
}

View File

@@ -33,15 +33,20 @@ type manager struct {
extractor *nbjwt.ClaimsExtractor
}
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool) Manager {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator := nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
func NewManager(store store.Store, issuer, audience, keysLocation, userIdClaim string, allAudiences []string, idpRefreshKeys bool, keyFetcher nbjwt.KeyFetcher) Manager {
var jwtValidator *nbjwt.Validator
if keyFetcher != nil {
jwtValidator = nbjwt.NewValidatorWithKeyFetcher(issuer, allAudiences, keyFetcher)
} else {
// @note if invalid/missing parameters are sent the validator will instantiate
// but it will fail when validating and parsing the token
jwtValidator = nbjwt.NewValidator(
issuer,
allAudiences,
keysLocation,
idpRefreshKeys,
)
}
claimsExtractor := nbjwt.NewClaimsExtractor(
nbjwt.WithAudience(audience),

View File

@@ -52,7 +52,7 @@ func TestAuthManager_GetAccountInfoFromPAT(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
user, pat, _, _, err := manager.GetPATInfo(context.Background(), token)
if err != nil {
@@ -92,7 +92,7 @@ func TestAuthManager_MarkPATUsed(t *testing.T) {
t.Fatalf("Error when saving account: %s", err)
}
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
err = manager.MarkPATUsed(context.Background(), "tokenId")
if err != nil {
@@ -142,7 +142,7 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
// these tests only assert groups are parsed from token as per account settings
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}})
manager := auth.NewManager(store, "", "", "", "", []string{}, false)
manager := auth.NewManager(store, "", "", "", "", []string{}, false, nil)
t.Run("JWT groups disabled", func(t *testing.T) {
userAuth, err := manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
@@ -225,7 +225,7 @@ func TestAuthManager_ValidateAndParseToken(t *testing.T) {
keyId := "test-key"
// note, we can use a nil store because ValidateAndParseToken does not use it in it's flow
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false)
manager := auth.NewManager(nil, issuer, audience, server.URL, userIdClaim, []string{audience}, false, nil)
customClaim := func(name string) string {
return fmt.Sprintf("%s/%s", audience, name)

View File

@@ -130,6 +130,10 @@ func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock()
defer gl.mux.RUnlock()
if gl.db == nil {
return nil, fmt.Errorf("geolocation database is not available")
}
var record Record
err := gl.db.Lookup(ip, &record)
if err != nil {
@@ -173,8 +177,14 @@ func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, er
func (gl *geolocationImpl) Stop() error {
close(gl.stopCh)
if gl.db != nil {
if err := gl.db.Close(); err != nil {
gl.mux.Lock()
db := gl.db
gl.db = nil
gl.mux.Unlock()
if db != nil {
if err := db.Close(); err != nil {
return err
}
}

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -764,11 +763,10 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
// Saving a group linked to network router should update account peers and send peer update
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
resourcesManager := resources.NewManager(manager.Store, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, manager)
networksManager := networks.NewManager(manager.Store, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(context.Background(), userID, &networkTypes.Network{
ID: "network_test",

View File

@@ -113,13 +113,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
domainManager.SetClusterCapabilities(serviceProxyController)
serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, domainManager)
serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, proxyMgr, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
@@ -243,13 +242,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
domainManager.SetClusterCapabilities(serviceProxyController)
serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, domainManager)
serviceManager := reverseproxymanager.NewManager(store, am, serviceProxyController, proxyMgr, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false, nil)
authManagerMock := &serverauth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/telemetry"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
@@ -48,6 +49,8 @@ type EmbeddedIdPConfig struct {
// Existing local users are preserved and will be able to login again if re-enabled.
// Cannot be enabled if no external identity provider connectors are configured.
LocalAuthDisabled bool
// StaticConnectors are additional connectors to seed during initialization
StaticConnectors []dex.Connector
}
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
@@ -157,6 +160,7 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
RedirectURIs: cliRedirectURIs,
},
},
StaticConnectors: c.StaticConnectors,
}
// Add owner user if provided
@@ -193,6 +197,9 @@ type OAuthConfigProvider interface {
// Management server has embedded Dex and can validate tokens via localhost,
// avoiding external network calls and DNS resolution issues during startup.
GetLocalKeysLocation() string
// GetKeyFetcher returns a KeyFetcher that reads keys directly from the IDP storage,
// or nil if direct key fetching is not supported (falls back to HTTP).
GetKeyFetcher() nbjwt.KeyFetcher
GetClientIDs() []string
GetUserIDClaim() string
GetTokenEndpoint() string
@@ -593,6 +600,11 @@ func (m *EmbeddedIdPManager) GetCLIRedirectURLs() []string {
return m.config.CLIRedirectURIs
}
// GetKeyFetcher returns a KeyFetcher that reads keys directly from Dex storage.
func (m *EmbeddedIdPManager) GetKeyFetcher() nbjwt.KeyFetcher {
return m.provider.GetJWKS
}
// GetKeysLocation returns the JWKS endpoint URL for token validation.
func (m *EmbeddedIdPManager) GetKeysLocation() string {
return m.provider.GetKeysLocation()

View File

@@ -0,0 +1,235 @@
// Package migration provides utility functions for migrating from the external IdP solution in pre v0.62.0
// to the new embedded IdP manager (Dex based), which is the default in v0.62.0 and later.
// It includes functions to seed connectors and migrate existing users to use these connectors.
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// Server is the dependency interface that migration functions use to access
// the main data store and the activity event store.
type Server interface {
Store() Store
EventStore() EventStore // may return nil
}
const idpSeedInfoKey = "IDP_SEED_INFO"
const dryRunEnvKey = "NB_IDP_MIGRATION_DRY_RUN"
func isDryRun() bool {
return os.Getenv(dryRunEnvKey) == "true"
}
var ErrNoSeedInfo = errors.New("no seed info found in environment")
// SeedConnectorFromEnv reads the IDP_SEED_INFO env var, base64-decodes it,
// and JSON-unmarshals it into a dex.Connector. Returns nil if not set.
func SeedConnectorFromEnv() (*dex.Connector, error) {
val, ok := os.LookupEnv(idpSeedInfoKey)
if !ok || val == "" {
return nil, ErrNoSeedInfo
}
decoded, err := base64.StdEncoding.DecodeString(val)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
var conn dex.Connector
if err := json.Unmarshal(decoded, &conn); err != nil {
return nil, fmt.Errorf("json unmarshal: %w", err)
}
return &conn, nil
}
// MigrateUsersToStaticConnectors re-keys every user ID in the main store (and
// the activity store, if present) so that it encodes the given connector ID,
// skipping users that have already been migrated. Set NB_IDP_MIGRATION_DRY_RUN=true
// to log what would happen without writing any changes.
func MigrateUsersToStaticConnectors(s Server, conn *dex.Connector) error {
ctx := context.Background()
if isDryRun() {
log.Info("[DRY RUN] migration dry-run mode enabled, no changes will be written")
}
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Reconciliation pass: fix activity store for users already migrated in main DB
// but whose activity references may still use old IDs (from a previous partial failure).
if s.EventStore() != nil && !isDryRun() {
if err := reconcileActivityStore(ctx, s.EventStore(), users); err != nil {
return err
}
}
var migratedCount, skippedCount int
for _, user := range users {
_, _, decErr := dex.DecodeDexUserID(user.Id)
if decErr == nil {
skippedCount++
continue
}
newUserID := dex.EncodeDexUserID(user.Id, conn.ID)
if isDryRun() {
log.Infof("[DRY RUN] would migrate user %s -> %s (account: %s)", user.Id, newUserID, user.AccountID)
migratedCount++
continue
}
if err := migrateUser(ctx, s, user.Id, user.AccountID, newUserID); err != nil {
return err
}
migratedCount++
}
if isDryRun() {
log.Infof("[DRY RUN] migration summary: %d users would be migrated, %d already migrated", migratedCount, skippedCount)
} else {
log.Infof("migration complete: %d users migrated, %d already migrated", migratedCount, skippedCount)
}
return nil
}
// reconcileActivityStore updates activity store references for users already migrated
// in the main DB whose activity entries may still use old IDs from a previous partial failure.
func reconcileActivityStore(ctx context.Context, eventStore EventStore, users []*types.User) error {
for _, user := range users {
originalID, _, err := dex.DecodeDexUserID(user.Id)
if err != nil {
// skip users that aren't migrated, they will be handled in the main migration loop
continue
}
if err := eventStore.UpdateUserID(ctx, originalID, user.Id); err != nil {
return fmt.Errorf("reconcile activity store for user %s: %w", user.Id, err)
}
}
return nil
}
// migrateUser updates a single user's ID in both the main store and the activity store.
func migrateUser(ctx context.Context, s Server, oldID, accountID, newID string) error {
if err := s.Store().UpdateUserID(ctx, accountID, oldID, newID); err != nil {
return fmt.Errorf("failed to update user ID for user %s: %w", oldID, err)
}
if s.EventStore() == nil {
return nil
}
if err := s.EventStore().UpdateUserID(ctx, oldID, newID); err != nil {
return fmt.Errorf("failed to update activity store user ID for user %s: %w", oldID, err)
}
return nil
}
// PopulateUserInfo fetches user email and name from the external IDP and updates
// the store for users that are missing this information.
func PopulateUserInfo(s Server, idpManager idp.Manager, dryRun bool) error {
ctx := context.Background()
users, err := s.Store().ListUsers(ctx)
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
// Build a map of IDP user ID -> UserData from the external IDP
allAccounts, err := idpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to fetch accounts from IDP: %w", err)
}
idpUsers := make(map[string]*idp.UserData)
for _, accountUsers := range allAccounts {
for _, userData := range accountUsers {
idpUsers[userData.ID] = userData
}
}
log.Infof("fetched %d users from IDP", len(idpUsers))
var updatedCount, skippedCount, notFoundCount int
for _, user := range users {
if user.IsServiceUser {
skippedCount++
continue
}
if user.Email != "" && user.Name != "" {
skippedCount++
continue
}
// The user ID in the store may be the original IDP ID or a Dex-encoded ID.
// Try to decode the Dex format first to get the original IDP ID.
lookupID := user.Id
if originalID, _, decErr := dex.DecodeDexUserID(user.Id); decErr == nil {
lookupID = originalID
}
idpUser, found := idpUsers[lookupID]
if !found {
notFoundCount++
log.Debugf("user %s (lookup: %s) not found in IDP, skipping", user.Id, lookupID)
continue
}
email := user.Email
name := user.Name
if email == "" && idpUser.Email != "" {
email = idpUser.Email
}
if name == "" && idpUser.Name != "" {
name = idpUser.Name
}
if email == user.Email && name == user.Name {
skippedCount++
continue
}
if dryRun {
log.Infof("[DRY RUN] would update user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
continue
}
if err := s.Store().UpdateUserInfo(ctx, user.Id, email, name); err != nil {
return fmt.Errorf("failed to update user info for %s: %w", user.Id, err)
}
log.Infof("updated user %s: email=%q, name=%q", user.Id, email, name)
updatedCount++
}
if dryRun {
log.Infof("[DRY RUN] user info summary: %d would be updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
} else {
log.Infof("user info population complete: %d updated, %d skipped, %d not found in IDP", updatedCount, skippedCount, notFoundCount)
}
return nil
}

View File

@@ -0,0 +1,828 @@
package migration
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
)
// testStore is a hand-written mock for MigrationStore.
type testStore struct {
listUsersFunc func(ctx context.Context) ([]*types.User, error)
updateUserIDFunc func(ctx context.Context, accountID, oldUserID, newUserID string) error
updateUserInfoFunc func(ctx context.Context, userID, email, name string) error
checkSchemaFunc func(checks []SchemaCheck) []SchemaError
updateCalls []updateUserIDCall
updateInfoCalls []updateUserInfoCall
}
type updateUserIDCall struct {
AccountID string
OldUserID string
NewUserID string
}
type updateUserInfoCall struct {
UserID string
Email string
Name string
}
func (s *testStore) ListUsers(ctx context.Context) ([]*types.User, error) {
return s.listUsersFunc(ctx)
}
func (s *testStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
s.updateCalls = append(s.updateCalls, updateUserIDCall{accountID, oldUserID, newUserID})
return s.updateUserIDFunc(ctx, accountID, oldUserID, newUserID)
}
func (s *testStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
s.updateInfoCalls = append(s.updateInfoCalls, updateUserInfoCall{userID, email, name})
if s.updateUserInfoFunc != nil {
return s.updateUserInfoFunc(ctx, userID, email, name)
}
return nil
}
func (s *testStore) CheckSchema(checks []SchemaCheck) []SchemaError {
if s.checkSchemaFunc != nil {
return s.checkSchemaFunc(checks)
}
return nil
}
type testServer struct {
store Store
eventStore EventStore
}
func (s *testServer) Store() Store { return s.store }
func (s *testServer) EventStore() EventStore { return s.eventStore }
func TestSeedConnectorFromEnv(t *testing.T) {
t.Run("returns ErrNoSeedInfo when env var is not set", func(t *testing.T) {
os.Unsetenv(idpSeedInfoKey)
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns ErrNoSeedInfo when env var is empty", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "")
conn, err := SeedConnectorFromEnv()
assert.ErrorIs(t, err, ErrNoSeedInfo)
assert.Nil(t, conn)
})
t.Run("returns error on invalid base64", func(t *testing.T) {
t.Setenv(idpSeedInfoKey, "not-valid-base64!!!")
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "base64 decode")
})
t.Run("returns error on invalid JSON", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("not json"))
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NotErrorIs(t, err, ErrNoSeedInfo)
assert.Error(t, err)
assert.Nil(t, conn)
assert.Contains(t, err.Error(), "json unmarshal")
})
t.Run("successfully decodes valid connector", func(t *testing.T) {
expected := dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-provider",
Config: map[string]any{
"issuer": "https://example.com",
"clientID": "my-client-id",
"clientSecret": "my-secret",
},
}
data, err := json.Marshal(expected)
require.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
t.Setenv(idpSeedInfoKey, encoded)
conn, err := SeedConnectorFromEnv()
assert.NoError(t, err)
require.NotNil(t, conn)
assert.Equal(t, expected.Type, conn.Type)
assert.Equal(t, expected.Name, conn.Name)
assert.Equal(t, expected.ID, conn.ID)
assert.Equal(t, expected.Config["issuer"], conn.Config["issuer"])
})
}
func TestMigrateUsersToStaticConnectors(t *testing.T) {
connector := &dex.Connector{
Type: "oidc",
Name: "Test Provider",
ID: "test-connector",
}
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil },
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("migrates single user with correct encoded ID", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
expectedNewID := dex.EncodeDexUserID("user-1", "test-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
require.Len(t, ms.updateCalls, 1)
assert.Equal(t, "account-1", ms.updateCalls[0].AccountID)
assert.Equal(t, "user-1", ms.updateCalls[0].OldUserID)
assert.Equal(t, expectedNewID, ms.updateCalls[0].NewUserID)
})
t.Run("migrates multiple users", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 3)
})
t.Run("returns error when UpdateUserID fails", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
callCount++
if callCount == 2 {
return fmt.Errorf("update failed")
}
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user ID for user user-2")
})
t.Run("stops on first UpdateUserID error", func(t *testing.T) {
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return fmt.Errorf("update failed")
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.Error(t, err)
assert.Len(t, ms.updateCalls, 1) // stopped after first error
})
t.Run("skips already migrated users", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("migrates only non-migrated users in mixed state", func(t *testing.T) {
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
{Id: "user-3", AccountID: "account-2"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
// Only user-2 and user-3 should be migrated
assert.Len(t, ms.updateCalls, 2)
assert.Equal(t, "user-2", ms.updateCalls[0].OldUserID)
assert.Equal(t, "user-3", ms.updateCalls[1].OldUserID)
})
t.Run("dry run does not call UpdateUserID", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
users := []*types.User{
{Id: "user-1", AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 0)
})
t.Run("dry run skips already migrated users", func(t *testing.T) {
t.Setenv(dryRunEnvKey, "true")
alreadyMigratedID := dex.EncodeDexUserID("user-1", "test-connector")
users := []*types.User{
{Id: alreadyMigratedID, AccountID: "account-1"},
{Id: "user-2", AccountID: "account-1"},
}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return users, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
t.Fatal("UpdateUserID should not be called in dry-run mode")
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
})
t.Run("dry run disabled by default", func(t *testing.T) {
user := &types.User{Id: "user-1", AccountID: "account-1"}
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{user}, nil
},
updateUserIDFunc: func(ctx context.Context, accountID, oldUserID, newUserID string) error {
return nil
},
}
srv := &testServer{store: ms}
err := MigrateUsersToStaticConnectors(srv, connector)
assert.NoError(t, err)
assert.Len(t, ms.updateCalls, 1) // proves it's not in dry-run
})
}
func TestPopulateUserInfo(t *testing.T) {
noopUpdateID := func(ctx context.Context, accountID, oldUserID, newUserID string) error { return nil }
t.Run("succeeds with no users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) { return nil, nil },
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("returns error when ListUsers fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return nil, fmt.Errorf("db error")
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to list users")
})
t.Run("returns error when GetAllAccounts fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{{Id: "user-1", AccountID: "acc-1"}}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return nil, fmt.Errorf("idp error")
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to fetch accounts from IDP")
})
t.Run("updates user with missing email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "user1@example.com", Name: "User One"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User One", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing email when name exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: "Existing Name"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "user1@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "user1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "Existing Name", ms.updateInfoCalls[0].Name)
})
t.Run("updates only missing name when email exists", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "existing@example.com", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "idp@example.com", Name: "IDP Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, "existing@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "IDP Name", ms.updateInfoCalls[0].Name)
})
t.Run("skips users that already have both email and name", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "user1@example.com", Name: "User One"},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "different@example.com", Name: "Different Name"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips service users", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "svc-1", AccountID: "acc-1", Email: "", Name: "", IsServiceUser: true},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "svc-1", Email: "svc@example.com", Name: "Service"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips users not found in IDP", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "different-user", Email: "other@example.com", Name: "Other"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("looks up dex-encoded user IDs by original ID", func(t *testing.T) {
dexEncodedID := dex.EncodeDexUserID("original-idp-id", "my-connector")
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: dexEncodedID, AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "original-idp-id", Email: "user@example.com", Name: "User"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 1)
assert.Equal(t, dexEncodedID, ms.updateInfoCalls[0].UserID)
assert.Equal(t, "user@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "User", ms.updateInfoCalls[0].Name)
})
t.Run("handles multiple users across multiple accounts", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "already@set.com", Name: "Already Set"},
{Id: "user-3", AccountID: "acc-2", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "User 1"},
{ID: "user-2", Email: "u2@example.com", Name: "User 2"},
},
"acc-2": {
{ID: "user-3", Email: "u3@example.com", Name: "User 3"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
require.Len(t, ms.updateInfoCalls, 2)
assert.Equal(t, "user-1", ms.updateInfoCalls[0].UserID)
assert.Equal(t, "u1@example.com", ms.updateInfoCalls[0].Email)
assert.Equal(t, "user-3", ms.updateInfoCalls[1].UserID)
assert.Equal(t, "u3@example.com", ms.updateInfoCalls[1].Email)
})
t.Run("returns error when UpdateUserInfo fails", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "u1@example.com", Name: "User 1"}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to update user info for user-1")
})
t.Run("stops on first UpdateUserInfo error", func(t *testing.T) {
callCount := 0
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
callCount++
return fmt.Errorf("db write error")
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.Error(t, err)
assert.Equal(t, 1, callCount)
})
t.Run("dry run does not call UpdateUserInfo", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
{Id: "user-2", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
updateUserInfoFunc: func(ctx context.Context, userID, email, name string) error {
t.Fatal("UpdateUserInfo should not be called in dry-run mode")
return nil
},
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {
{ID: "user-1", Email: "u1@example.com", Name: "U1"},
{ID: "user-2", Email: "u2@example.com", Name: "U2"},
},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, true)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
t.Run("skips user when IDP has empty email and name too", func(t *testing.T) {
ms := &testStore{
listUsersFunc: func(ctx context.Context) ([]*types.User, error) {
return []*types.User{
{Id: "user-1", AccountID: "acc-1", Email: "", Name: ""},
}, nil
},
updateUserIDFunc: noopUpdateID,
}
mockIDP := &idp.MockIDP{
GetAllAccountsFunc: func(ctx context.Context) (map[string][]*idp.UserData, error) {
return map[string][]*idp.UserData{
"acc-1": {{ID: "user-1", Email: "", Name: ""}},
}, nil
},
}
srv := &testServer{store: ms}
err := PopulateUserInfo(srv, mockIDP, false)
assert.NoError(t, err)
assert.Empty(t, ms.updateInfoCalls)
})
}
func TestSchemaError_String(t *testing.T) {
t.Run("missing table", func(t *testing.T) {
e := SchemaError{Table: "jobs"}
assert.Equal(t, `table "jobs" is missing`, e.String())
})
t.Run("missing column", func(t *testing.T) {
e := SchemaError{Table: "users", Column: "email"}
assert.Equal(t, `column "email" on table "users" is missing`, e.String())
})
}
func TestRequiredSchema(t *testing.T) {
// Verify RequiredSchema covers all the tables touched by UpdateUserID and UpdateUserInfo.
expectedTables := []string{
"users",
"personal_access_tokens",
"peers",
"accounts",
"user_invites",
"proxy_access_tokens",
"jobs",
}
schemaTableNames := make([]string, len(RequiredSchema))
for i, s := range RequiredSchema {
schemaTableNames[i] = s.Table
}
for _, expected := range expectedTables {
assert.Contains(t, schemaTableNames, expected, "RequiredSchema should include table %q", expected)
}
}
func TestCheckSchema_MockStore(t *testing.T) {
t.Run("returns nil when all schema exists", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return nil
},
}
errs := ms.CheckSchema(RequiredSchema)
assert.Empty(t, errs)
})
t.Run("returns errors for missing tables", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "jobs"},
{Table: "proxy_access_tokens"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "jobs", errs[0].Table)
assert.Equal(t, "", errs[0].Column)
assert.Equal(t, "proxy_access_tokens", errs[1].Table)
})
t.Run("returns errors for missing columns", func(t *testing.T) {
ms := &testStore{
checkSchemaFunc: func(checks []SchemaCheck) []SchemaError {
return []SchemaError{
{Table: "users", Column: "email"},
{Table: "users", Column: "name"},
}
},
}
errs := ms.CheckSchema(RequiredSchema)
require.Len(t, errs, 2)
assert.Equal(t, "users", errs[0].Table)
assert.Equal(t, "email", errs[0].Column)
})
}

View File

@@ -0,0 +1,82 @@
package migration
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/server/types"
)
// SchemaCheck represents a table and the columns required on it.
type SchemaCheck struct {
Table string
Columns []string
}
// RequiredSchema lists all tables and columns that the migration tool needs.
// If any are missing, the user must upgrade their management server first so
// that the automatic GORM migrations create them.
var RequiredSchema = []SchemaCheck{
{Table: "users", Columns: []string{"id", "email", "name", "account_id"}},
{Table: "personal_access_tokens", Columns: []string{"user_id", "created_by"}},
{Table: "peers", Columns: []string{"user_id"}},
{Table: "accounts", Columns: []string{"created_by"}},
{Table: "user_invites", Columns: []string{"created_by"}},
{Table: "proxy_access_tokens", Columns: []string{"created_by"}},
{Table: "jobs", Columns: []string{"triggered_by"}},
}
// SchemaError describes a single missing table or column.
type SchemaError struct {
Table string
Column string // empty when the whole table is missing
}
func (e SchemaError) String() string {
if e.Column == "" {
return fmt.Sprintf("table %q is missing", e.Table)
}
return fmt.Sprintf("column %q on table %q is missing", e.Column, e.Table)
}
// Store defines the data store operations required for IdP user migration.
// This interface is separate from the main store.Store interface because these methods
// are only used during one-time migration and should be removed once migration tooling
// is no longer needed.
//
// The SQL store implementations (SqlStore) already have these methods on their concrete
// types, so they satisfy this interface via Go's structural typing with zero code changes.
type Store interface {
// ListUsers returns all users across all accounts.
ListUsers(ctx context.Context) ([]*types.User, error)
// UpdateUserID atomically updates a user's ID and all foreign key references
// across the database (peers, groups, policies, PATs, etc.).
UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error
// UpdateUserInfo updates a user's email and name in the store.
UpdateUserInfo(ctx context.Context, userID, email, name string) error
// CheckSchema verifies that all tables and columns required by the migration
// exist in the database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
}
// RequiredEventSchema lists all tables and columns that the migration tool needs
// in the activity/event store.
var RequiredEventSchema = []SchemaCheck{
{Table: "events", Columns: []string{"initiator_id", "target_id"}},
{Table: "deleted_users", Columns: []string{"id"}},
}
// EventStore defines the activity event store operations required for migration.
// Like Store, this is a temporary interface for migration tooling only.
type EventStore interface {
// CheckSchema verifies that all tables and columns required by the migration
// exist in the event database. Returns a list of problems; an empty slice means OK.
CheckSchema(checks []SchemaCheck) []SchemaError
// UpdateUserID updates all event references (initiator_id, target_id) and
// deleted_users records to use the new user ID format.
UpdateUserID(ctx context.Context, oldUserID, newUserID string) error
}

View File

@@ -64,10 +64,19 @@ type Manager interface {
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
type instanceStore interface {
GetAccountsCounter(ctx context.Context) (int64, error)
}
type embeddedIdP interface {
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
}
// DefaultManager is the default implementation of Manager.
type DefaultManager struct {
store store.Store
embeddedIdpManager *idp.EmbeddedIdPManager
store instanceStore
embeddedIdpManager embeddedIdP
setupRequired bool
setupMu sync.RWMutex
@@ -82,18 +91,18 @@ type DefaultManager struct {
// NewManager creates a new instance manager.
// If idpManager is not an EmbeddedIdPManager, setup-related operations will return appropriate defaults.
func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) (Manager, error) {
embeddedIdp, _ := idpManager.(*idp.EmbeddedIdPManager)
embeddedIdp, ok := idpManager.(*idp.EmbeddedIdPManager)
m := &DefaultManager{
store: store,
embeddedIdpManager: embeddedIdp,
setupRequired: false,
store: store,
setupRequired: false,
httpClient: &http.Client{
Timeout: httpTimeout,
},
}
if embeddedIdp != nil {
if ok && embeddedIdp != nil {
m.embeddedIdpManager = embeddedIdp
err := m.loadSetupRequired(ctx)
if err != nil {
return nil, err
@@ -143,36 +152,61 @@ func (m *DefaultManager) IsSetupRequired(_ context.Context) (bool, error) {
// CreateOwnerUser creates the initial owner user in the embedded IDP.
func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
m.setupMu.RLock()
setupRequired := m.setupRequired
m.setupMu.RUnlock()
if err := m.validateSetupInfo(email, password, name); err != nil {
return nil, err
}
if !setupRequired {
m.setupMu.Lock()
defer m.setupMu.Unlock()
if !m.setupRequired {
return nil, status.Errorf(status.PreconditionFailed, "setup already completed")
}
if err := m.checkSetupRequiredFromDB(ctx); err != nil {
var sErr *status.Error
if errors.As(err, &sErr) && sErr.Type() == status.PreconditionFailed {
m.setupRequired = false
}
return nil, err
}
userData, err := m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
if err != nil {
return nil, fmt.Errorf("failed to create user in embedded IdP: %w", err)
}
m.setupMu.Lock()
m.setupRequired = false
m.setupMu.Unlock()
log.WithContext(ctx).Infof("created owner user %s in embedded IdP", email)
return userData, nil
}
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return fmt.Errorf("failed to check accounts: %w", err)
}
if numAccounts > 0 {
return status.Errorf(status.PreconditionFailed, "setup already completed")
}
users, err := m.embeddedIdpManager.GetAllAccounts(ctx)
if err != nil {
return fmt.Errorf("failed to check IdP users: %w", err)
}
if len(users) > 0 {
return status.Errorf(status.PreconditionFailed, "setup already completed")
}
return nil
}
func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if email == "" {
return status.Errorf(status.InvalidArgument, "email is required")
@@ -189,6 +223,9 @@ func (m *DefaultManager) validateSetupInfo(email, password, name string) error {
if len(password) < 8 {
return status.Errorf(status.InvalidArgument, "password must be at least 8 characters")
}
if len(password) > 72 {
return status.Errorf(status.InvalidArgument, "password must be at most 72 characters")
}
return nil
}

View File

@@ -3,7 +3,12 @@ package instance
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -11,173 +16,215 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
)
// mockStore implements a minimal store.Store for testing
type mockIdP struct {
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
users map[string][]*idp.UserData
getAllAccountsErr error
}
func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
}
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
}
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.getAllAccountsErr != nil {
return nil, m.getAllAccountsErr
}
return m.users, nil
}
type mockStore struct {
accountsCount int64
err error
}
func (m *mockStore) GetAccountsCounter(ctx context.Context) (int64, error) {
func (m *mockStore) GetAccountsCounter(_ context.Context) (int64, error) {
if m.err != nil {
return 0, m.err
}
return m.accountsCount, nil
}
// mockEmbeddedIdPManager wraps the real EmbeddedIdPManager for testing
type mockEmbeddedIdPManager struct {
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
}
func (m *mockEmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createUserFunc != nil {
return m.createUserFunc(ctx, email, password, name)
func newTestManager(idpMock *mockIdP, storeMock *mockStore) *DefaultManager {
return &DefaultManager{
store: storeMock,
embeddedIdpManager: idpMock,
setupRequired: true,
httpClient: &http.Client{Timeout: httpTimeout},
}
return &idp.UserData{
ID: "test-user-id",
Email: email,
Name: name,
}, nil
}
// testManager is a test implementation that accepts our mock types
type testManager struct {
store *mockStore
embeddedIdpManager *mockEmbeddedIdPManager
}
func (m *testManager) IsSetupRequired(ctx context.Context) (bool, error) {
if m.embeddedIdpManager == nil {
return false, nil
}
count, err := m.store.GetAccountsCounter(ctx)
if err != nil {
return false, err
}
return count == 0, nil
}
func (m *testManager) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.embeddedIdpManager == nil {
return nil, errors.New("embedded IDP is not enabled")
}
return m.embeddedIdpManager.CreateUserWithPassword(ctx, email, password, name)
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil, // No embedded IDP
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when embedded IDP is disabled")
}
func TestIsSetupRequired_NoAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts exist")
}
func TestIsSetupRequired_AccountsExist(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 1},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when accounts exist")
}
func TestIsSetupRequired_MultipleAccounts(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 5},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
required, err := manager.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required when multiple accounts exist")
}
func TestIsSetupRequired_StoreError(t *testing.T) {
manager := &testManager{
store: &mockStore{err: errors.New("database error")},
embeddedIdpManager: &mockEmbeddedIdPManager{},
}
_, err := manager.IsSetupRequired(context.Background())
assert.Error(t, err, "should return error when store fails")
}
func TestCreateOwnerUser_Success(t *testing.T) {
expectedEmail := "admin@example.com"
expectedName := "Admin User"
expectedPassword := "securepassword123"
idpMock := &mockIdP{}
mgr := newTestManager(idpMock, &mockStore{})
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
assert.Equal(t, expectedEmail, email)
assert.Equal(t, expectedPassword, password)
assert.Equal(t, expectedName, name)
return &idp.UserData{
ID: "created-user-id",
Email: email,
Name: name,
}, nil
},
},
}
userData, err := manager.CreateOwnerUser(context.Background(), expectedEmail, expectedPassword, expectedName)
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "created-user-id", userData.ID)
assert.Equal(t, expectedEmail, userData.Email)
assert.Equal(t, expectedName, userData.Name)
assert.Equal(t, "admin@example.com", userData.Email)
_, err = mgr.CreateOwnerUser(context.Background(), "admin2@example.com", "password123", "Admin2")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{})
mgr.setupRequired = false
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_EmbeddedIdPDisabled(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: nil,
}
mgr := &DefaultManager{setupRequired: true}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when embedded IDP is disabled")
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "embedded IDP is not enabled")
}
func TestCreateOwnerUser_IdPError(t *testing.T) {
manager := &testManager{
store: &mockStore{accountsCount: 0},
embeddedIdpManager: &mockEmbeddedIdPManager{
createUserFunc: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return nil, errors.New("user already exists")
},
idpMock := &mockIdP{
createUserFunc: func(_ context.Context, _, _, _ string) (*idp.UserData, error) {
return nil, errors.New("provider error")
},
}
mgr := newTestManager(idpMock, &mockStore{})
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
assert.Error(t, err, "should return error when IDP fails")
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "provider error")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after IdP error")
}
func TestCreateOwnerUser_TransientDBError_DoesNotBlockSetup(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{err: errors.New("connection refused")})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "connection refused")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after transient DB error")
mgr.store = &mockStore{}
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "admin@example.com", userData.Email)
}
func TestCreateOwnerUser_TransientIdPError_DoesNotBlockSetup(t *testing.T) {
idpMock := &mockIdP{getAllAccountsErr: errors.New("connection reset")}
mgr := newTestManager(idpMock, &mockStore{})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "connection reset")
required, _ := mgr.IsSetupRequired(context.Background())
assert.True(t, required, "setup should still be required after transient IdP error")
idpMock.getAllAccountsErr = nil
userData, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.NoError(t, err)
assert.Equal(t, "admin@example.com", userData.Email)
}
func TestCreateOwnerUser_DBCheckBlocksConcurrent(t *testing.T) {
idpMock := &mockIdP{
users: map[string][]*idp.UserData{
"acc1": {{ID: "existing-user"}},
},
}
mgr := newTestManager(idpMock, &mockStore{})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_DBCheckBlocksWhenAccountsExist(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{accountsCount: 1})
_, err := mgr.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}
func TestCreateOwnerUser_ConcurrentRequests(t *testing.T) {
var idpCallCount atomic.Int32
var successCount atomic.Int32
var failCount atomic.Int32
idpMock := &mockIdP{
createUserFunc: func(_ context.Context, email, _, _ string) (*idp.UserData, error) {
idpCallCount.Add(1)
time.Sleep(50 * time.Millisecond)
return &idp.UserData{ID: "user-1", Email: email, Name: "Owner"}, nil
},
}
mgr := newTestManager(idpMock, &mockStore{})
var wg sync.WaitGroup
for i := range 10 {
wg.Add(1)
go func(idx int) {
defer wg.Done()
_, err := mgr.CreateOwnerUser(
context.Background(),
fmt.Sprintf("owner%d@example.com", idx),
"password1234",
fmt.Sprintf("Owner%d", idx),
)
if err != nil {
failCount.Add(1)
} else {
successCount.Add(1)
}
}(i)
}
wg.Wait()
assert.Equal(t, int32(1), successCount.Load(), "exactly one concurrent setup request should succeed")
assert.Equal(t, int32(9), failCount.Load(), "remaining concurrent requests should fail")
assert.Equal(t, int32(1), idpCallCount.Load(), "IdP CreateUser should be called exactly once")
}
func TestIsSetupRequired_EmbeddedIdPDisabled(t *testing.T) {
mgr := &DefaultManager{}
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required)
}
func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
mgr := newTestManager(&mockIdP{}, &mockStore{})
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required)
mgr.setupMu.Lock()
mgr.setupRequired = false
mgr.setupMu.Unlock()
required, err = mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required)
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{
setupRequired: true,
}
manager := &DefaultManager{setupRequired: true}
tests := []struct {
name string
@@ -188,11 +235,10 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
errorMsg string
}{
{
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
expectError: false,
name: "valid request",
email: "admin@example.com",
password: "password123",
userName: "Admin User",
},
{
name: "empty email",
@@ -235,11 +281,24 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
errorMsg: "password must be at least 8 characters",
},
{
name: "password exactly 8 characters",
name: "password exactly 8 characters",
email: "admin@example.com",
password: "12345678",
userName: "Admin User",
},
{
name: "password exactly 72 characters",
email: "admin@example.com",
password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiii",
userName: "Admin User",
},
{
name: "password too long",
email: "admin@example.com",
password: "12345678",
password: "aaaaaaaabbbbbbbbccccccccddddddddeeeeeeeeffffffffgggggggghhhhhhhhiiiiiiiij",
userName: "Admin User",
expectError: false,
expectError: true,
errorMsg: "password must be at most 72 characters",
},
}
@@ -255,14 +314,3 @@ func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
})
}
}
func TestDefaultManager_CreateOwnerUser_SetupAlreadyCompleted(t *testing.T) {
manager := &DefaultManager{
setupRequired: false,
embeddedIdpManager: &idp.EmbeddedIdPManager{},
}
_, err := manager.CreateOwnerUser(context.Background(), "admin@example.com", "password123", "Admin")
require.Error(t, err)
assert.Contains(t, err.Error(), "setup already completed")
}

View File

@@ -46,7 +46,7 @@ type MockAccountManager struct {
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, groupName, accountID string) (*types.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName)
return am.GetGroupByNameFunc(ctx, groupName, accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -26,11 +25,10 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
resourcesManager := resources.NewManager(s, groupsManager, &am, nil)
manager := NewManager(s, resourcesManager, routerManager, &am)
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
require.NoError(t, err)
@@ -38,28 +36,6 @@ func Test_GetAllNetworksReturnsNetworks(t *testing.T) {
require.Equal(t, "testNetworkId", networks[0].ID)
}
func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
networks, err := manager.GetAllNetworks(ctx, accountID, userID)
require.Error(t, err)
require.Nil(t, networks)
}
func Test_GetNetworkReturnsNetwork(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
@@ -72,40 +48,16 @@ func Test_GetNetworkReturnsNetwork(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
resourcesManager := resources.NewManager(s, groupsManager, &am, nil)
manager := NewManager(s, resourcesManager, routerManager, &am)
networks, err := manager.GetNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
require.Equal(t, "testNetworkId", networks.ID)
}
func Test_GetNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
network, err := manager.GetNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
require.Nil(t, network)
}
func Test_CreateNetworkSuccessfully(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
@@ -120,42 +72,16 @@ func Test_CreateNetworkSuccessfully(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
resourcesManager := resources.NewManager(s, groupsManager, &am, nil)
manager := NewManager(s, resourcesManager, routerManager, &am)
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
require.NoError(t, err)
require.Equal(t, network.Name, createdNetwork.Name)
}
func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
userID := "testUserId"
network := &types.Network{
AccountID: "testAccountId",
Name: "new-network",
}
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
createdNetwork, err := manager.CreateNetwork(ctx, userID, network)
require.Error(t, err)
require.Nil(t, createdNetwork)
}
func Test_DeleteNetworkSuccessfully(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
@@ -168,38 +94,15 @@ func Test_DeleteNetworkSuccessfully(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
resourcesManager := resources.NewManager(s, groupsManager, &am, nil)
manager := NewManager(s, resourcesManager, routerManager, &am)
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
}
func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
err = manager.DeleteNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
}
func Test_UpdateNetworkSuccessfully(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
@@ -215,40 +118,12 @@ func Test_UpdateNetworkSuccessfully(t *testing.T) {
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
resourcesManager := resources.NewManager(s, groupsManager, &am, nil)
manager := NewManager(s, resourcesManager, routerManager, &am)
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
require.NoError(t, err)
require.Equal(t, network.Name, updatedNetwork.Name)
}
func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
userID := "testUserId"
network := &types.Network{
AccountID: "testAccountId",
ID: "testNetworkId",
Name: "new-network",
}
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
am := mock_server.MockAccountManager{}
permissionsManager := permissions.NewManager(s)
groupsManager := groups.NewManagerMock()
routerManager := routers.NewManagerMock()
resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am, nil)
manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am)
updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network)
require.Error(t, err)
require.Nil(t, updatedNetwork)
}

View File

@@ -7,13 +7,11 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
@@ -27,41 +25,17 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
require.Len(t, resources, 2)
}
func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, resources)
}
func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
@@ -72,41 +46,17 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err)
require.Len(t, resources, 2)
}
func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, resources)
}
func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
@@ -119,43 +69,17 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
require.Equal(t, resourceID, resource.ID)
}
func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testResourceId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, resources)
}
func Test_CreateResourceSuccessfully(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
@@ -172,48 +96,18 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err)
require.Equal(t, resource.Name, createdResource.Name)
}
func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
userID := "testUserId"
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
Name: "testResourceId",
Description: "description",
Address: "192.168.1.1",
}
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, createdResource)
}
func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
@@ -230,12 +124,11 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -258,12 +151,11 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -290,13 +182,12 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -325,12 +216,11 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -357,43 +247,11 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
require.Nil(t, updatedResource)
}
func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testResourceId"
resource := &types.NetworkResource{
AccountID: accountID,
NetworkID: networkID,
Name: resourceID,
Description: "new-description",
Address: "1.2.3.0/24",
}
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -412,37 +270,13 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
manager := NewManager(store, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
}
func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testResourceId"
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(store)
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
}

View File

@@ -6,11 +6,9 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
@@ -24,9 +22,8 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
manager := NewManager(s, &am)
routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
@@ -34,27 +31,6 @@ func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) {
require.Equal(t, "testRouterId", routers[0].ID)
}
func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, routers)
}
func Test_GetRouterReturnsRouter(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
@@ -67,37 +43,14 @@ func Test_GetRouterReturnsRouter(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
manager := NewManager(s, &am)
router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
require.Equal(t, "testRouterId", router.ID)
}
func Test_GetRouterReturnsPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
resourceID := "testRouterId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, router)
}
func Test_CreateRouterSuccessfully(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
@@ -111,9 +64,8 @@ func Test_CreateRouterSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
manager := NewManager(s, &am)
createdRouter, err := manager.CreateRouter(ctx, userID, router)
require.NoError(t, err)
@@ -124,29 +76,6 @@ func Test_CreateRouterSuccessfully(t *testing.T) {
require.Equal(t, router.Masquerade, createdRouter.Masquerade)
}
func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
userID := "testUserId"
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true)
if err != nil {
require.NoError(t, err)
}
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
createdRouter, err := manager.CreateRouter(ctx, userID, router)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, createdRouter)
}
func Test_DeleteRouterSuccessfully(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
@@ -159,35 +88,13 @@ func Test_DeleteRouterSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
manager := NewManager(s, &am)
err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID)
require.NoError(t, err)
}
func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
accountID := "testAccountId"
userID := "testUserId"
networkID := "testNetworkId"
routerID := "testRouterId"
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
}
func Test_UpdateRouterSuccessfully(t *testing.T) {
ctx := context.Background()
userID := "testAdminId"
@@ -201,34 +108,10 @@ func Test_UpdateRouterSuccessfully(t *testing.T) {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
manager := NewManager(s, &am)
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
require.NoError(t, err)
require.Equal(t, router.Metric, updatedRouter.Metric)
}
func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) {
ctx := context.Background()
userID := "testUserId"
router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true)
if err != nil {
require.NoError(t, err)
}
s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
permissionsManager := permissions.NewManager(s)
am := mock_server.MockAccountManager{}
manager := NewManager(s, permissionsManager, &am)
updatedRouter, err := manager.UpdateRouter(ctx, userID, router)
require.Error(t, err)
require.Equal(t, status.NewPermissionDeniedError(), err)
require.Nil(t, updatedRouter)
}

View File

@@ -799,7 +799,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if !temporary {
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
@@ -1412,9 +1414,11 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err
}
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") {
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
})
}
}
return peerDeletedEvents, nil

View File

@@ -2080,7 +2080,8 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
mode, listen_port, port_auto_assigned, source, source_peer, terminated
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
@@ -2097,6 +2098,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
err := row.Scan(
&s.ID,
&s.AccountID,
@@ -2112,6 +2114,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&s.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
&mode,
&s.ListenPort,
&s.PortAutoAssigned,
&source,
&sourcePeer,
&s.Terminated,
)
if err != nil {
return nil, err
@@ -2143,6 +2151,15 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if sessionPublicKey.Valid {
s.SessionPublicKey = sessionPublicKey.String
}
if mode.Valid {
s.Mode = mode.String
}
if source.Valid {
s.Source = source.String
}
if sourcePeer.Valid {
s.SourcePeer = sourcePeer.String
}
s.Targets = []*rpservice.Target{}
return &s, nil
@@ -5445,7 +5462,7 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
@@ -5463,7 +5480,7 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster,
result := s.db.Model(&proxy.Proxy{}).
Select("cluster_address as address, COUNT(*) as connected_proxies").
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Group("cluster_address").
Scan(&clusters)
@@ -5475,6 +5492,63 @@ func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster,
return clusters, nil
}
// proxyActiveThreshold is the maximum age of a heartbeat for a proxy to be
// considered active. Must be at least 2x the heartbeat interval (1 min).
const proxyActiveThreshold = 2 * time.Minute
var validCapabilityColumns = map[string]struct{}{
"supports_custom_ports": {},
"require_subdomain": {},
}
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy reported the capability.
func (s *SqlStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
return s.getClusterCapability(ctx, clusterAddr, "supports_custom_ports")
}
// GetClusterRequireSubdomain returns whether any active proxy in the cluster
// requires a subdomain. Returns nil when no proxy reported the capability.
func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
}
// getClusterCapability returns an aggregated boolean capability for the given
// cluster. It checks active (connected, recently seen) proxies and returns:
// - *true if any proxy in the cluster has the capability set to true,
// - *false if at least one proxy reported but none set it to true,
// - nil if no proxy reported the capability at all.
func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column string) *bool {
if _, ok := validCapabilityColumns[column]; !ok {
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
return nil
}
var result struct {
HasCapability bool
AnyTrue bool
}
err := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
Where("cluster_address = ? AND status = ? AND last_seen > ?",
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
Scan(&result).Error
if err != nil {
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
return nil
}
if !result.HasCapability {
return nil
}
return &result.AnyTrue
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)

View File

@@ -0,0 +1,177 @@
package store
// This file contains migration-only methods on SqlStore.
// They satisfy the migration.Store interface via duck typing.
// Delete this file when migration tooling is no longer needed.
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/idp/migration"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
func (s *SqlStore) CheckSchema(checks []migration.SchemaCheck) []migration.SchemaError {
migrator := s.db.Migrator()
var errs []migration.SchemaError
for _, check := range checks {
if !migrator.HasTable(check.Table) {
errs = append(errs, migration.SchemaError{Table: check.Table})
continue
}
for _, col := range check.Columns {
if !migrator.HasColumn(check.Table, col) {
errs = append(errs, migration.SchemaError{Table: check.Table, Column: col})
}
}
}
return errs
}
func (s *SqlStore) ListUsers(ctx context.Context) ([]*types.User, error) {
tx := s.db
var users []*types.User
result := tx.Find(&users)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when listing users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue listing users from store")
}
for _, user := range users {
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
}
return users, nil
}
// txDeferFKConstraints defers foreign key constraint checks for the duration of the transaction.
// MySQL is already handled by s.transaction (SET FOREIGN_KEY_CHECKS = 0).
func (s *SqlStore) txDeferFKConstraints(tx *gorm.DB) error {
if s.storeEngine == types.SqliteStoreEngine {
return tx.Exec("PRAGMA defer_foreign_keys = ON").Error
}
if s.storeEngine != types.PostgresStoreEngine {
return nil
}
// GORM creates FK constraints as NOT DEFERRABLE by default, so
// SET CONSTRAINTS ALL DEFERRED is a no-op unless we ALTER them first.
err := tx.Exec(`
DO $$ DECLARE r RECORD;
BEGIN
FOR r IN SELECT conname, conrelid::regclass AS tbl
FROM pg_constraint WHERE contype = 'f' AND NOT condeferrable
LOOP
EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I DEFERRABLE INITIALLY IMMEDIATE', r.tbl, r.conname);
END LOOP;
END $$
`).Error
if err != nil {
return fmt.Errorf("make FK constraints deferrable: %w", err)
}
return tx.Exec("SET CONSTRAINTS ALL DEFERRED").Error
}
// txRestoreFKConstraints reverts FK constraints back to NOT DEFERRABLE after the
// deferred updates are done but before the transaction commits.
func (s *SqlStore) txRestoreFKConstraints(tx *gorm.DB) error {
if s.storeEngine != types.PostgresStoreEngine {
return nil
}
return tx.Exec(`
DO $$ DECLARE r RECORD;
BEGIN
FOR r IN SELECT conname, conrelid::regclass AS tbl
FROM pg_constraint WHERE contype = 'f' AND condeferrable
LOOP
EXECUTE format('ALTER TABLE %s ALTER CONSTRAINT %I NOT DEFERRABLE', r.tbl, r.conname);
END LOOP;
END $$
`).Error
}
func (s *SqlStore) UpdateUserInfo(ctx context.Context, userID, email, name string) error {
user := &types.User{Email: email, Name: name}
if err := user.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user info: %w", err)
}
result := s.db.Model(&types.User{}).Where("id = ?", userID).Updates(map[string]any{
"email": user.Email,
"name": user.Name,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("error updating user info for %s: %s", userID, result.Error)
return status.Errorf(status.Internal, "failed to update user info")
}
return nil
}
func (s *SqlStore) UpdateUserID(ctx context.Context, accountID, oldUserID, newUserID string) error {
type fkUpdate struct {
model any
column string
where string
}
updates := []fkUpdate{
{&types.PersonalAccessToken{}, "user_id", "user_id = ?"},
{&types.PersonalAccessToken{}, "created_by", "created_by = ?"},
{&nbpeer.Peer{}, "user_id", "user_id = ?"},
{&types.UserInviteRecord{}, "created_by", "created_by = ?"},
{&types.Account{}, "created_by", "created_by = ?"},
{&types.ProxyAccessToken{}, "created_by", "created_by = ?"},
{&types.Job{}, "triggered_by", "triggered_by = ?"},
}
log.Info("Updating user ID in the store")
err := s.transaction(func(tx *gorm.DB) error {
if err := s.txDeferFKConstraints(tx); err != nil {
return err
}
for _, u := range updates {
if err := tx.Model(u.model).Where(u.where, oldUserID).Update(u.column, newUserID).Error; err != nil {
return fmt.Errorf("update %s: %w", u.column, err)
}
}
if err := tx.Model(&types.User{}).Where(accountAndIDQueryCondition, accountID, oldUserID).Update("id", newUserID).Error; err != nil {
return fmt.Errorf("update users: %w", err)
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update user ID in the store: %s", err)
return status.Errorf(status.Internal, "failed to update user ID in store")
}
log.Info("Restoring FK constraints")
err = s.transaction(func(tx *gorm.DB) error {
if err := s.txRestoreFKConstraints(tx); err != nil {
return fmt.Errorf("restore FK constraints: %w", err)
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to restore FK constraints after user ID update: %s", err)
return status.Errorf(status.Internal, "failed to restore FK constraints after user ID update")
}
return nil
}

View File

@@ -121,7 +121,7 @@ type Store interface {
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
@@ -287,6 +287,8 @@ type Store interface {
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)

View File

@@ -1361,6 +1361,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx)
}
// GetClusterRequireSubdomain mocks base method.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetCustomDomain mocks base method.
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
m.ctrl.T.Helper()
@@ -1438,18 +1466,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou
}
// GetGroupByName mocks base method.
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) {
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName)
ret0, _ := ret[0].(*types2.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
}
// GetGroupsByIDs mocks base method.
@@ -1946,6 +1974,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID)
}
// GetRoutingPeerNetworks mocks base method.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
@@ -2333,21 +2376,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
}
// GetRoutingPeerNetworks mocks base method.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// IsPrimaryAccount mocks base method.
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
m.ctrl.T.Helper()

View File

@@ -84,6 +84,12 @@ func setupTestAccount() *Account {
},
},
Groups: map[string]*Group{
"groupAll": {
ID: "groupAll",
Name: "All",
Peers: []string{"peer1", "peer2", "peer3", "peer11", "peer12", "peer21", "peer31", "peer32", "peer41", "peer51", "peer61"},
Issued: GroupIssuedAPI,
},
"group1": {
ID: "group1",
Peers: []string{"peer11", "peer12"},

View File

@@ -375,6 +375,10 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, err
}
if targetUser.AccountID != accountID {
return nil, status.NewPermissionDeniedError()
}
// @note this is essential to prevent non admin users with Pats create permission frpm creating one for a service user
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
@@ -407,6 +411,10 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return err
}
if targetUser.AccountID != accountID {
return status.NewPermissionDeniedError()
}
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return status.NewAdminPermissionError()
}
@@ -438,6 +446,10 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, err
}
if targetUser.AccountID != accountID {
return nil, status.NewPermissionDeniedError()
}
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}
@@ -457,6 +469,10 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, err
}
if targetUser.AccountID != accountID {
return nil, status.NewPermissionDeniedError()
}
if initiatorUserID != targetUserID && !(initiatorUser.HasAdminPower() && targetUser.IsServiceUser) {
return nil, status.NewAdminPermissionError()
}
@@ -691,9 +707,15 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
updatedUser.Role = update.Role
updatedUser.Blocked = update.Blocked
updatedUser.AutoGroups = update.AutoGroups
// these two fields can't be set via API, only via direct call to the method
// these fields can't be set via API, only via direct call to the method
updatedUser.Issued = update.Issued
updatedUser.IntegrationReference = update.IntegrationReference
if update.Name != "" {
updatedUser.Name = update.Name
}
if update.Email != "" {
updatedUser.Email = update.Email
}
var transferredOwnerRole bool
result, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update)

View File

@@ -336,6 +336,104 @@ func TestUser_GetAllPATs(t *testing.T) {
assert.Equal(t, 2, len(pats))
}
func TestUser_PAT_CrossAccountProtection(t *testing.T) {
const (
accountAID = "accountA"
accountBID = "accountB"
userAID = "userA"
adminBID = "adminB"
serviceUserBID = "serviceUserB"
regularUserBID = "regularUserB"
tokenBID = "tokenB1"
hashedTokenB = "SoMeHaShEdToKeNB"
)
setupStore := func(t *testing.T) (*DefaultAccountManager, func()) {
t.Helper()
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err, "creating store")
accountA := newAccountWithId(context.Background(), accountAID, userAID, "", "", "", false)
require.NoError(t, s.SaveAccount(context.Background(), accountA))
accountB := newAccountWithId(context.Background(), accountBID, adminBID, "", "", "", false)
accountB.Users[serviceUserBID] = &types.User{
Id: serviceUserBID,
AccountID: accountBID,
IsServiceUser: true,
ServiceUserName: "svcB",
Role: types.UserRoleAdmin,
PATs: map[string]*types.PersonalAccessToken{
tokenBID: {
ID: tokenBID,
HashedToken: hashedTokenB,
},
},
}
accountB.Users[regularUserBID] = &types.User{
Id: regularUserBID,
AccountID: accountBID,
Role: types.UserRoleUser,
}
require.NoError(t, s.SaveAccount(context.Background(), accountB))
pm := permissions.NewManager(s)
am := &DefaultAccountManager{
Store: s,
eventStore: &activity.InMemoryEventStore{},
permissionsManager: pm,
}
return am, cleanup
}
t.Run("CreatePAT for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.CreatePAT(context.Background(), accountAID, userAID, serviceUserBID, "xss-token", 7)
require.Error(t, err, "cross-account CreatePAT must fail")
_, err = am.CreatePAT(context.Background(), accountAID, userAID, regularUserBID, "xss-token", 7)
require.Error(t, err, "cross-account CreatePAT for regular user must fail")
_, err = am.CreatePAT(context.Background(), accountBID, adminBID, serviceUserBID, "legit-token", 7)
require.NoError(t, err, "same-account CreatePAT should succeed")
})
t.Run("DeletePAT for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
err := am.DeletePAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID)
require.Error(t, err, "cross-account DeletePAT must fail")
})
t.Run("GetPAT for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.GetPAT(context.Background(), accountAID, userAID, serviceUserBID, tokenBID)
require.Error(t, err, "cross-account GetPAT must fail")
})
t.Run("GetAllPATs for user in different account is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.GetAllPATs(context.Background(), accountAID, userAID, serviceUserBID)
require.Error(t, err, "cross-account GetAllPATs must fail")
})
t.Run("CreatePAT with forged accountID targeting foreign user is denied", func(t *testing.T) {
am, cleanup := setupStore(t)
t.Cleanup(cleanup)
_, err := am.CreatePAT(context.Background(), accountAID, userAID, adminBID, "forged", 7)
require.Error(t, err, "forged accountID CreatePAT must fail")
})
}
func TestUser_Copy(t *testing.T) {
// this is an imaginary case which will never be in DB this way
user := types.User{