mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
19 Commits
feature/up
...
feature/mi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a88dd8a692 | ||
|
|
2b00a429d7 | ||
|
|
29a31001bd | ||
|
|
023d85f42a | ||
|
|
3434760526 | ||
|
|
e33e5673c5 | ||
|
|
71d98940dc | ||
|
|
5f7a6b839b | ||
|
|
1481dbcdd7 | ||
|
|
7956f676a4 | ||
|
|
ddcf9f820b | ||
|
|
475ce092c8 | ||
|
|
80c49c268f | ||
|
|
cdfe0f3d41 | ||
|
|
794976263e | ||
|
|
77ea4b7444 | ||
|
|
9fd34718a6 | ||
|
|
ea37d4b768 | ||
|
|
f7ee019f26 |
@@ -1402,9 +1402,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
|||||||
return fmt.Errorf("error saving groups: %w", err)
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups)
|
|
||||||
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
|
|
||||||
|
|
||||||
user.AutoGroups = updatedAutoGroups
|
user.AutoGroups = updatedAutoGroups
|
||||||
if err = transaction.SaveUser(ctx, user); err != nil {
|
if err = transaction.SaveUser(ctx, user); err != nil {
|
||||||
return fmt.Errorf("error saving user: %w", err)
|
return fmt.Errorf("error saving user: %w", err)
|
||||||
|
|||||||
@@ -918,6 +918,7 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
||||||
|
b.Setenv("NETBIRD_STORE_ENGINE", "postgres")
|
||||||
claims := auth.UserAuth{
|
claims := auth.UserAuth{
|
||||||
Domain: "example.com",
|
Domain: "example.com",
|
||||||
UserId: "pvt-domain-user",
|
UserId: "pvt-domain-user",
|
||||||
@@ -945,6 +946,18 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
|||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
a, err := am.Store.GetAccount(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Groups = genGroups()
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(context.Background(), a)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
users := genUsers("priv", 100)
|
users := genUsers("priv", 100)
|
||||||
|
|
||||||
acc, err := am.Store.GetAccount(context.Background(), id)
|
acc, err := am.Store.GetAccount(context.Background(), id)
|
||||||
@@ -1005,6 +1018,41 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func genGroups() map[string]*types.Group {
|
||||||
|
return map[string]*types.Group{
|
||||||
|
"one": {
|
||||||
|
Name: "one",
|
||||||
|
},
|
||||||
|
"two": {
|
||||||
|
Name: "two",
|
||||||
|
},
|
||||||
|
"three": {
|
||||||
|
Name: "three",
|
||||||
|
},
|
||||||
|
"four": {
|
||||||
|
Name: "four",
|
||||||
|
},
|
||||||
|
"five": {
|
||||||
|
Name: "five",
|
||||||
|
},
|
||||||
|
"six": {
|
||||||
|
Name: "six",
|
||||||
|
},
|
||||||
|
"seven": {
|
||||||
|
Name: "seven",
|
||||||
|
},
|
||||||
|
"eight": {
|
||||||
|
Name: "eight",
|
||||||
|
},
|
||||||
|
"nine": {
|
||||||
|
Name: "nine",
|
||||||
|
},
|
||||||
|
"ten": {
|
||||||
|
Name: "ten",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func genUsers(p string, n int) map[string]*types.User {
|
func genUsers(p string, n int) map[string]*types.User {
|
||||||
users := map[string]*types.User{}
|
users := map[string]*types.User{}
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -1723,6 +1771,13 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
Id: "user1",
|
Id: "user1",
|
||||||
Role: types.UserRoleAdmin,
|
Role: types.UserRoleAdmin,
|
||||||
AutoGroups: []string{"group1"},
|
AutoGroups: []string{"group1"},
|
||||||
|
Groups: []*types.GroupUser{
|
||||||
|
{
|
||||||
|
AccountID: "account1",
|
||||||
|
UserID: "user1",
|
||||||
|
GroupID: "group1",
|
||||||
|
},
|
||||||
|
},
|
||||||
PATs: map[string]*types.PersonalAccessToken{
|
PATs: map[string]*types.PersonalAccessToken{
|
||||||
"pat1": {
|
"pat1": {
|
||||||
ID: "pat1",
|
ID: "pat1",
|
||||||
@@ -1742,6 +1797,13 @@ func TestAccount_Copy(t *testing.T) {
|
|||||||
Peers: []string{"peer1"},
|
Peers: []string{"peer1"},
|
||||||
Resources: []types.Resource{},
|
Resources: []types.Resource{},
|
||||||
GroupPeers: []types.GroupPeer{},
|
GroupPeers: []types.GroupPeer{},
|
||||||
|
GroupUsers: []types.GroupUser{
|
||||||
|
{
|
||||||
|
AccountID: "account1",
|
||||||
|
UserID: "user1",
|
||||||
|
GroupID: "group1",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Policies: []*types.Policy{
|
Policies: []*types.Policy{
|
||||||
|
|||||||
@@ -380,13 +380,6 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
|||||||
AutoGroups: []string{groupForUsers.ID},
|
AutoGroups: []string{groupForUsers.ID},
|
||||||
}
|
}
|
||||||
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
|
account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain, "", "", false)
|
||||||
account.Routes[routeResource.ID] = routeResource
|
|
||||||
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
|
||||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
|
||||||
account.Policies = append(account.Policies, policy)
|
|
||||||
account.SetupKeys[setupKey.Id] = setupKey
|
|
||||||
account.Users[user.Id] = user
|
|
||||||
|
|
||||||
err := am.Store.SaveAccount(context.Background(), account)
|
err := am.Store.SaveAccount(context.Background(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -400,6 +393,23 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t
|
|||||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||||
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
_ = am.CreateGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||||
|
|
||||||
|
account, err = am.Store.GetAccount(context.Background(), accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Routes[routeResource.ID] = routeResource
|
||||||
|
account.Routes[routePeerGroupResource.ID] = routePeerGroupResource
|
||||||
|
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||||
|
account.Policies = append(account.Policies, policy)
|
||||||
|
account.SetupKeys[setupKey.Id] = setupKey
|
||||||
|
account.Users[user.Id] = user
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(context.Background(), account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -539,3 +539,103 @@ func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CleanupOrphanedIDs removes non-existent IDs from the JSON array column.
|
||||||
|
// T is the type of the model that contains the list.
|
||||||
|
// This migration cleans up the lists field by removing IDs that no longer exist in the target table.
|
||||||
|
func CleanupOrphanedIDs[T, S any](ctx context.Context, db *gorm.DB, columnName string) error {
|
||||||
|
var sourceModel T
|
||||||
|
var fkModel S
|
||||||
|
|
||||||
|
if !db.Migrator().HasTable(&sourceModel) {
|
||||||
|
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", sourceModel)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !db.Migrator().HasTable(&fkModel) {
|
||||||
|
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", fkModel)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt := &gorm.Statement{DB: db}
|
||||||
|
err := stmt.Parse(&sourceModel)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse model: %w", err)
|
||||||
|
}
|
||||||
|
tableName := stmt.Schema.Table
|
||||||
|
|
||||||
|
if !db.Migrator().HasColumn(&sourceModel, columnName) {
|
||||||
|
log.WithContext(ctx).Debugf("Column %s does not exist in table %s, no migration needed", columnName, tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
var rows []map[string]any
|
||||||
|
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
|
||||||
|
return fmt.Errorf("find rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all valid IDs from the fk table
|
||||||
|
var validIDs []string
|
||||||
|
if err := tx.Model(fkModel).Select("id").Pluck("id", &validIDs).Error; err != nil {
|
||||||
|
return fmt.Errorf("fetch valid group IDs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
validIDMap := make(map[string]bool, len(validIDs))
|
||||||
|
for _, id := range validIDs {
|
||||||
|
validIDMap[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedCount := 0
|
||||||
|
for _, row := range rows {
|
||||||
|
jsonValue, ok := row[columnName].(string)
|
||||||
|
if !ok || jsonValue == "" || jsonValue == "null" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var list []string
|
||||||
|
if err := json.Unmarshal([]byte(jsonValue), &list); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to unmarshal %s for id %v: %v", columnName, row["id"], err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(list) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out non-existent IDs
|
||||||
|
cleanedList := make([]string, 0, len(list))
|
||||||
|
for _, groupID := range list {
|
||||||
|
if validIDMap[groupID] {
|
||||||
|
cleanedList = append(cleanedList, groupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only update if there were orphaned ids removed
|
||||||
|
if len(cleanedList) != len(list) {
|
||||||
|
cleanedJSON, err := json.Marshal(cleanedList)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal cleaned %s: %w", columnName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, cleanedJSON).Error; err != nil {
|
||||||
|
return fmt.Errorf("update row with id %v: %w", row["id"], err)
|
||||||
|
}
|
||||||
|
updatedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if updatedCount > 0 {
|
||||||
|
log.WithContext(ctx).Infof("Cleaned up orphaned %s in %d rows from table %s", columnName, updatedCount, tableName)
|
||||||
|
} else {
|
||||||
|
log.WithContext(ctx).Debugf("No orphaned %s found in table %s", columnName, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).Infof("Cleanup of orphaned %s from table %s completed", columnName, tableName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||||
}
|
}
|
||||||
err = db.AutoMigrate(
|
err = db.AutoMigrate(
|
||||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
|
||||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||||
@@ -177,7 +177,8 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
|||||||
generateAccountSQLTypes(account)
|
generateAccountSQLTypes(account)
|
||||||
|
|
||||||
// Encrypt sensitive user data before saving
|
// Encrypt sensitive user data before saving
|
||||||
for i := range account.UsersG {
|
for i, user := range account.UsersG {
|
||||||
|
user.StoreAutoGroups()
|
||||||
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
return fmt.Errorf("encrypt user: %w", err)
|
return fmt.Errorf("encrypt user: %w", err)
|
||||||
}
|
}
|
||||||
@@ -203,15 +204,35 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
|||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save account without UsersG.Groups to avoid FK constraint violations
|
||||||
|
// (groups must exist before group_users can reference them)
|
||||||
result = tx.
|
result = tx.
|
||||||
Session(&gorm.Session{FullSaveAssociations: true}).
|
Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Omit("UsersG.Groups").
|
||||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||||
Create(account)
|
Create(account)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return result.Error
|
return result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now save the user-group associations after both users and groups exist
|
||||||
|
for _, user := range account.UsersG {
|
||||||
|
if len(user.Groups) > 0 {
|
||||||
|
result = tx.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
|
||||||
|
UpdateAll: true,
|
||||||
|
}).Create(&user.Groups)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
took := time.Since(start)
|
took := time.Since(start)
|
||||||
if s.metrics != nil {
|
if s.metrics != nil {
|
||||||
@@ -219,7 +240,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
|||||||
}
|
}
|
||||||
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
||||||
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateAccountSQLTypes generates the GORM compatible types for the account
|
// generateAccountSQLTypes generates the GORM compatible types for the account
|
||||||
@@ -243,7 +264,7 @@ func generateAccountSQLTypes(account *types.Account) {
|
|||||||
pat.ID = id
|
pat.ID = id
|
||||||
user.PATsG = append(user.PATsG, *pat)
|
user.PATsG = append(user.PATsG, *pat)
|
||||||
}
|
}
|
||||||
account.UsersG = append(account.UsersG, *user)
|
account.UsersG = append(account.UsersG, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
for id, group := range account.Groups {
|
for id, group := range account.Groups {
|
||||||
@@ -453,6 +474,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
|||||||
userCopy := user.Copy()
|
userCopy := user.Copy()
|
||||||
userCopy.Email = user.Email
|
userCopy.Email = user.Email
|
||||||
userCopy.Name = user.Name
|
userCopy.Name = user.Name
|
||||||
|
userCopy.StoreAutoGroups()
|
||||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
return fmt.Errorf("encrypt user: %w", err)
|
return fmt.Errorf("encrypt user: %w", err)
|
||||||
}
|
}
|
||||||
@@ -472,16 +494,37 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
|||||||
userCopy := user.Copy()
|
userCopy := user.Copy()
|
||||||
userCopy.Email = user.Email
|
userCopy.Email = user.Email
|
||||||
userCopy.Name = user.Name
|
userCopy.Name = user.Name
|
||||||
|
userCopy.StoreAutoGroups()
|
||||||
|
|
||||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
return fmt.Errorf("encrypt user: %w", err)
|
return fmt.Errorf("encrypt user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
result := s.db.Save(userCopy)
|
err := s.transaction(func(tx *gorm.DB) error {
|
||||||
if result.Error != nil {
|
result := tx.Omit("Groups").Save(userCopy)
|
||||||
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "failed to save user to store")
|
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = tx.Delete(&types.GroupUser{}, "user_id = ?", user.Id)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to delete user groups from store: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(userCopy.Groups) != 0 {
|
||||||
|
result = tx.Save(userCopy.Groups)
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to save user groups to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save user to store: %s", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -617,6 +660,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
|||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := tx.
|
result := tx.
|
||||||
|
Preload("Groups").
|
||||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||||
Where("personal_access_tokens.id = ?", patID).Take(&user)
|
Where("personal_access_tokens.id = ?", patID).Take(&user)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -631,6 +675,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
|||||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.LoadAutoGroups()
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -641,7 +687,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := tx.Take(&user, idQueryCondition, userID)
|
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@@ -653,6 +699,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.LoadAutoGroups()
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -680,7 +728,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
|||||||
}
|
}
|
||||||
|
|
||||||
var users []*types.User
|
var users []*types.User
|
||||||
result := tx.Find(&users, accountIDCondition, accountID)
|
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
@@ -693,6 +741,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
|||||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||||
}
|
}
|
||||||
|
user.LoadAutoGroups()
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
@@ -705,7 +754,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
|||||||
}
|
}
|
||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
|
||||||
@@ -717,6 +766,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
|||||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.LoadAutoGroups()
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -867,7 +918,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
|||||||
Preload("SetupKeysG").
|
Preload("SetupKeysG").
|
||||||
Preload("PeersG").
|
Preload("PeersG").
|
||||||
Preload("UsersG").
|
Preload("UsersG").
|
||||||
|
Preload("UsersG.Groups").
|
||||||
|
Preload("GroupsG").
|
||||||
Preload("GroupsG.GroupPeers").
|
Preload("GroupsG.GroupPeers").
|
||||||
|
Preload("GroupsG.GroupUsers").
|
||||||
Preload("RoutesG").
|
Preload("RoutesG").
|
||||||
Preload("NameServerGroupsG").
|
Preload("NameServerGroupsG").
|
||||||
Preload("PostureChecks").
|
Preload("PostureChecks").
|
||||||
@@ -908,13 +962,14 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
|||||||
pat.UserID = ""
|
pat.UserID = ""
|
||||||
user.PATs[pat.ID] = &pat
|
user.PATs[pat.ID] = &pat
|
||||||
}
|
}
|
||||||
if user.AutoGroups == nil {
|
if user.Groups == nil {
|
||||||
user.AutoGroups = []string{}
|
user.Groups = []*types.GroupUser{}
|
||||||
}
|
}
|
||||||
|
user.LoadAutoGroups()
|
||||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||||
}
|
}
|
||||||
account.Users[user.Id] = &user
|
account.Users[user.Id] = user
|
||||||
user.PATsG = nil
|
user.PATsG = nil
|
||||||
}
|
}
|
||||||
account.UsersG = nil
|
account.UsersG = nil
|
||||||
@@ -1116,8 +1171,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
|||||||
groupIDs = append(groupIDs, g.ID)
|
groupIDs = append(groupIDs, g.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(3)
|
wg.Add(4)
|
||||||
errChan = make(chan error, 3)
|
errChan = make(chan error, 4)
|
||||||
|
|
||||||
var pats []types.PersonalAccessToken
|
var pats []types.PersonalAccessToken
|
||||||
go func() {
|
go func() {
|
||||||
@@ -1149,6 +1204,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var groupUsers []types.GroupUser
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
var err error
|
||||||
|
groupUsers, err = s.getGroupUsers(ctx, userIDs)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
close(errChan)
|
close(errChan)
|
||||||
for e := range errChan {
|
for e := range errChan {
|
||||||
@@ -1174,6 +1239,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
|||||||
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupsByUserID := make(map[string][]*types.GroupUser)
|
||||||
|
for i := range groupUsers {
|
||||||
|
gu := &groupUsers[i]
|
||||||
|
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
|
||||||
|
}
|
||||||
|
|
||||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||||
for i := range account.SetupKeysG {
|
for i := range account.SetupKeysG {
|
||||||
key := &account.SetupKeysG[i]
|
key := &account.SetupKeysG[i]
|
||||||
@@ -1188,7 +1259,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
|||||||
|
|
||||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||||
for i := range account.UsersG {
|
for i := range account.UsersG {
|
||||||
user := &account.UsersG[i]
|
user := account.UsersG[i]
|
||||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1199,6 +1270,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
|||||||
user.PATs[pat.ID] = pat
|
user.PATs[pat.ID] = pat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
user.Groups = groupsByUserID[user.Id]
|
||||||
|
user.LoadAutoGroups()
|
||||||
account.Users[user.Id] = user
|
account.Users[user.Id] = user
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1595,44 +1668,41 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
|||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
|
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||||
rows, err := s.pool.Query(ctx, query, accountID)
|
rows, err := s.pool.Query(ctx, query, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.User, error) {
|
||||||
var u types.User
|
var u types.User
|
||||||
var autoGroups []byte
|
|
||||||
var lastLogin, createdAt sql.NullTime
|
var lastLogin, createdAt sql.NullTime
|
||||||
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
||||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
if lastLogin.Valid {
|
return &u, err
|
||||||
u.LastLogin = &lastLogin.Time
|
|
||||||
}
|
|
||||||
if createdAt.Valid {
|
|
||||||
u.CreatedAt = createdAt.Time
|
|
||||||
}
|
|
||||||
if isServiceUser.Valid {
|
|
||||||
u.IsServiceUser = isServiceUser.Bool
|
|
||||||
}
|
|
||||||
if nonDeletable.Valid {
|
|
||||||
u.NonDeletable = nonDeletable.Bool
|
|
||||||
}
|
|
||||||
if blocked.Valid {
|
|
||||||
u.Blocked = blocked.Bool
|
|
||||||
}
|
|
||||||
if pendingApproval.Valid {
|
|
||||||
u.PendingApproval = pendingApproval.Bool
|
|
||||||
}
|
|
||||||
if autoGroups != nil {
|
|
||||||
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
|
|
||||||
} else {
|
|
||||||
u.AutoGroups = []string{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return u, err
|
|
||||||
|
if lastLogin.Valid {
|
||||||
|
u.LastLogin = &lastLogin.Time
|
||||||
|
}
|
||||||
|
if createdAt.Valid {
|
||||||
|
u.CreatedAt = createdAt.Time
|
||||||
|
}
|
||||||
|
if isServiceUser.Valid {
|
||||||
|
u.IsServiceUser = isServiceUser.Bool
|
||||||
|
}
|
||||||
|
if nonDeletable.Valid {
|
||||||
|
u.NonDeletable = nonDeletable.Bool
|
||||||
|
}
|
||||||
|
if blocked.Valid {
|
||||||
|
u.Blocked = blocked.Bool
|
||||||
|
}
|
||||||
|
if pendingApproval.Valid {
|
||||||
|
u.PendingApproval = pendingApproval.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
return &u, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -2038,6 +2108,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
|
|||||||
return groupPeers, nil
|
return groupPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
|
||||||
|
if len(userIDs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
|
||||||
|
rows, err := s.pool.Query(ctx, query, userIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return groupUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
|
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
|
||||||
var user types.User
|
var user types.User
|
||||||
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
|
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
|
||||||
@@ -2659,6 +2745,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||||
|
user := &types.GroupUser{
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
UserID: userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := s.db.Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
|
||||||
|
DoNothing: true,
|
||||||
|
}).Create(user).Error
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
|
||||||
|
return status.Errorf(status.Internal, "failed to add user to group")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
|
||||||
|
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
|
||||||
|
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to remove user from group")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// RemovePeerFromAllGroups removes a peer from all groups
|
// RemovePeerFromAllGroups removes a peer from all groups
|
||||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||||
err := s.db.
|
err := s.db.
|
||||||
|
|||||||
@@ -1372,6 +1372,7 @@ func TestSqlStore_CreateGroup(t *testing.T) {
|
|||||||
Peers: []string{},
|
Peers: []string{},
|
||||||
Resources: []types.Resource{},
|
Resources: []types.Resource{},
|
||||||
GroupPeers: []types.GroupPeer{},
|
GroupPeers: []types.GroupPeer{},
|
||||||
|
GroupUsers: []types.GroupUser{},
|
||||||
}
|
}
|
||||||
err = store.CreateGroup(context.Background(), group)
|
err = store.CreateGroup(context.Background(), group)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1396,6 +1397,7 @@ func TestSqlStore_CreateUpdateGroups(t *testing.T) {
|
|||||||
Peers: []string{},
|
Peers: []string{},
|
||||||
Resources: []types.Resource{},
|
Resources: []types.Resource{},
|
||||||
GroupPeers: []types.GroupPeer{},
|
GroupPeers: []types.GroupPeer{},
|
||||||
|
GroupUsers: []types.GroupUser{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "group-2",
|
ID: "group-2",
|
||||||
@@ -1404,6 +1406,7 @@ func TestSqlStore_CreateUpdateGroups(t *testing.T) {
|
|||||||
Peers: []string{},
|
Peers: []string{},
|
||||||
Resources: []types.Resource{},
|
Resources: []types.Resource{},
|
||||||
GroupPeers: []types.GroupPeer{},
|
GroupPeers: []types.GroupPeer{},
|
||||||
|
GroupUsers: []types.GroupUser{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = store.CreateGroups(context.Background(), accountID, groups)
|
err = store.CreateGroups(context.Background(), accountID, groups)
|
||||||
@@ -3059,7 +3062,7 @@ func TestSqlStore_SaveUser(t *testing.T) {
|
|||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Role: types.UserRoleAdmin,
|
Role: types.UserRoleAdmin,
|
||||||
IsServiceUser: false,
|
IsServiceUser: false,
|
||||||
AutoGroups: []string{"groupA", "groupB"},
|
AutoGroups: []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g3g"},
|
||||||
Blocked: false,
|
Blocked: false,
|
||||||
LastLogin: util.ToPtr(time.Now().UTC()),
|
LastLogin: util.ToPtr(time.Now().UTC()),
|
||||||
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
CreatedAt: time.Now().UTC().Add(-time.Hour),
|
||||||
@@ -3097,13 +3100,13 @@ func TestSqlStore_SaveUsers(t *testing.T) {
|
|||||||
Id: "user-1",
|
Id: "user-1",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Issued: "api",
|
Issued: "api",
|
||||||
AutoGroups: []string{"groupA", "groupB"},
|
AutoGroups: []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g3g"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: "user-2",
|
Id: "user-2",
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Issued: "integration",
|
Issued: "integration",
|
||||||
AutoGroups: []string{"groupA"},
|
AutoGroups: []string{"cfefqs706sqkneg59g2g"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = store.SaveUsers(context.Background(), users)
|
err = store.SaveUsers(context.Background(), users)
|
||||||
@@ -3113,7 +3116,7 @@ func TestSqlStore_SaveUsers(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, accountUsers, 4)
|
require.Len(t, accountUsers, 4)
|
||||||
|
|
||||||
users[1].AutoGroups = []string{"groupA", "groupC"}
|
users[1].AutoGroups = []string{"cfefqs706sqkneg59g2g", "cfefqs706sqkneg59g4g"}
|
||||||
err = store.SaveUsers(context.Background(), users)
|
err = store.SaveUsers(context.Background(), users)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -3151,7 +3154,7 @@ func TestSqlStore_SaveUserWithEncryption(t *testing.T) {
|
|||||||
Role: types.UserRoleUser,
|
Role: types.UserRoleUser,
|
||||||
Email: "",
|
Email: "",
|
||||||
Name: "",
|
Name: "",
|
||||||
AutoGroups: []string{"groupA"},
|
AutoGroups: []string{"cfefqs706sqkneg59g2g"},
|
||||||
}
|
}
|
||||||
err = store.SaveUser(context.Background(), user)
|
err = store.SaveUser(context.Background(), user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -3180,7 +3183,7 @@ func TestSqlStore_SaveUserWithEncryption(t *testing.T) {
|
|||||||
Role: types.UserRoleAdmin,
|
Role: types.UserRoleAdmin,
|
||||||
Email: "test@example.com",
|
Email: "test@example.com",
|
||||||
Name: "Test User",
|
Name: "Test User",
|
||||||
AutoGroups: []string{"groupB"},
|
AutoGroups: []string{"cfefqs706sqkneg59g3g"},
|
||||||
}
|
}
|
||||||
err = store.SaveUser(context.Background(), user)
|
err = store.SaveUser(context.Background(), user)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types
|
|||||||
for _, pat := range user.PATsG {
|
for _, pat := range user.PATsG {
|
||||||
user.PATs[pat.ID] = pat.Copy()
|
user.PATs[pat.ID] = pat.Copy()
|
||||||
}
|
}
|
||||||
|
user.LoadAutoGroups()
|
||||||
account.Users[user.Id] = user.Copy()
|
account.Users[user.Id] = user.Copy()
|
||||||
}
|
}
|
||||||
account.UsersG = nil
|
account.UsersG = nil
|
||||||
@@ -89,6 +90,9 @@ func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types
|
|||||||
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
|
||||||
for _, group := range account.GroupsG {
|
for _, group := range account.GroupsG {
|
||||||
account.Groups[group.ID] = group.Copy()
|
account.Groups[group.ID] = group.Copy()
|
||||||
|
if len(group.GroupUsers) == 0 {
|
||||||
|
account.Groups[group.ID] = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
account.GroupsG = nil
|
account.GroupsG = nil
|
||||||
|
|
||||||
@@ -175,10 +179,12 @@ func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*ty
|
|||||||
pat.UserID = ""
|
pat.UserID = ""
|
||||||
user.PATs[pat.ID] = &pat
|
user.PATs[pat.ID] = &pat
|
||||||
}
|
}
|
||||||
if user.AutoGroups == nil {
|
user.LoadAutoGroups()
|
||||||
|
if len(user.AutoGroups) == 0 {
|
||||||
user.AutoGroups = []string{}
|
user.AutoGroups = []string{}
|
||||||
|
user.Groups = []*types.GroupUser{}
|
||||||
}
|
}
|
||||||
account.Users[user.Id] = &user
|
account.Users[user.Id] = user
|
||||||
user.PATsG = nil
|
user.PATsG = nil
|
||||||
}
|
}
|
||||||
account.UsersG = nil
|
account.UsersG = nil
|
||||||
@@ -191,6 +197,9 @@ func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*ty
|
|||||||
if group.Resources == nil {
|
if group.Resources == nil {
|
||||||
group.Resources = []types.Resource{}
|
group.Resources = []types.Resource{}
|
||||||
}
|
}
|
||||||
|
if group.GroupUsers == nil {
|
||||||
|
group.GroupUsers = []types.GroupUser{}
|
||||||
|
}
|
||||||
account.Groups[group.ID] = group
|
account.Groups[group.ID] = group
|
||||||
}
|
}
|
||||||
account.GroupsG = nil
|
account.GroupsG = nil
|
||||||
@@ -259,7 +268,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
|||||||
|
|
||||||
models := []interface{}{
|
models := []interface{}{
|
||||||
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
|
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
|
||||||
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
|
||||||
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
||||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||||
@@ -609,10 +618,12 @@ func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
|
|||||||
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
|
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
|
||||||
for key, oldVal := range expected.Groups {
|
for key, oldVal := range expected.Groups {
|
||||||
newVal, ok := actual.Groups[key]
|
newVal, ok := actual.Groups[key]
|
||||||
|
if oldVal != nil && newVal != nil {
|
||||||
|
sort.Strings(oldVal.Peers)
|
||||||
|
sort.Strings(newVal.Peers)
|
||||||
|
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
|
||||||
|
}
|
||||||
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
|
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
|
||||||
sort.Strings(oldVal.Peers)
|
|
||||||
sort.Strings(newVal.Peers)
|
|
||||||
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
|
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
|
||||||
@@ -900,7 +911,7 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty
|
|||||||
|
|
||||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||||
for i := range account.UsersG {
|
for i := range account.UsersG {
|
||||||
user := &account.UsersG[i]
|
user := account.UsersG[i]
|
||||||
user.PATs = make(map[string]*types.PersonalAccessToken)
|
user.PATs = make(map[string]*types.PersonalAccessToken)
|
||||||
if userPats, ok := patsByUserID[user.Id]; ok {
|
if userPats, ok := patsByUserID[user.Id]; ok {
|
||||||
for j := range userPats {
|
for j := range userPats {
|
||||||
|
|||||||
@@ -89,6 +89,8 @@ type Store interface {
|
|||||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error
|
||||||
|
RemoveUserFromGroup(ctx context.Context, userID, groupID string) error
|
||||||
|
|
||||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||||
@@ -353,6 +355,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
|||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
||||||
},
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups")
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,6 +397,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
|||||||
func(db *gorm.DB) error {
|
func(db *gorm.DB) error {
|
||||||
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
return migration.CreateIndexIfNotExists[nbpeer.Peer](ctx, db, "idx_peers_key_unique", "key")
|
||||||
},
|
},
|
||||||
|
func(db *gorm.DB) error {
|
||||||
|
return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any {
|
||||||
|
return &types.GroupUser{
|
||||||
|
AccountID: accountID,
|
||||||
|
GroupID: value,
|
||||||
|
UserID: id,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ type Account struct {
|
|||||||
Peers map[string]*nbpeer.Peer `gorm:"-"`
|
Peers map[string]*nbpeer.Peer `gorm:"-"`
|
||||||
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
Users map[string]*User `gorm:"-"`
|
Users map[string]*User `gorm:"-"`
|
||||||
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
UsersG []*User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
Groups map[string]*Group `gorm:"-"`
|
Groups map[string]*Group `gorm:"-"`
|
||||||
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type Group struct {
|
|||||||
// Peers list of the group
|
// Peers list of the group
|
||||||
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
|
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
|
||||||
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||||
|
GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||||
|
|
||||||
// Resources contains a list of resources in that group
|
// Resources contains a list of resources in that group
|
||||||
Resources []Resource `gorm:"serializer:json"`
|
Resources []Resource `gorm:"serializer:json"`
|
||||||
@@ -41,6 +42,20 @@ type GroupPeer struct {
|
|||||||
PeerID string `gorm:"primaryKey"`
|
PeerID string `gorm:"primaryKey"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GroupUser struct {
|
||||||
|
AccountID string `gorm:"index"`
|
||||||
|
GroupID string `gorm:"primaryKey"`
|
||||||
|
UserID string `gorm:"primaryKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GroupUser) Copy() *GroupUser {
|
||||||
|
return &GroupUser{
|
||||||
|
AccountID: g.AccountID,
|
||||||
|
GroupID: g.GroupID,
|
||||||
|
UserID: g.UserID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (g *Group) LoadGroupPeers() {
|
func (g *Group) LoadGroupPeers() {
|
||||||
g.Peers = make([]string, len(g.GroupPeers))
|
g.Peers = make([]string, len(g.GroupPeers))
|
||||||
for i, peer := range g.GroupPeers {
|
for i, peer := range g.GroupPeers {
|
||||||
@@ -78,11 +93,13 @@ func (g *Group) Copy() *Group {
|
|||||||
Issued: g.Issued,
|
Issued: g.Issued,
|
||||||
Peers: make([]string, len(g.Peers)),
|
Peers: make([]string, len(g.Peers)),
|
||||||
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
||||||
|
GroupUsers: make([]GroupUser, len(g.GroupUsers)),
|
||||||
Resources: make([]Resource, len(g.Resources)),
|
Resources: make([]Resource, len(g.Resources)),
|
||||||
IntegrationReference: g.IntegrationReference,
|
IntegrationReference: g.IntegrationReference,
|
||||||
}
|
}
|
||||||
copy(group.Peers, g.Peers)
|
copy(group.Peers, g.Peers)
|
||||||
copy(group.GroupPeers, g.GroupPeers)
|
copy(group.GroupPeers, g.GroupPeers)
|
||||||
|
copy(group.GroupUsers, g.GroupUsers)
|
||||||
copy(group.Resources, g.Resources)
|
copy(group.Resources, g.Resources)
|
||||||
return group
|
return group
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,9 +85,11 @@ type User struct {
|
|||||||
// ServiceUserName is only set if IsServiceUser is true
|
// ServiceUserName is only set if IsServiceUser is true
|
||||||
ServiceUserName string
|
ServiceUserName string
|
||||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||||
AutoGroups []string `gorm:"serializer:json"`
|
AutoGroups []string `gorm:"-"`
|
||||||
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
// GroupUsers replaces old AutoGroups
|
||||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
Groups []*GroupUser `gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||||
|
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||||
|
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||||
Blocked bool
|
Blocked bool
|
||||||
// PendingApproval indicates whether the user requires approval before being activated
|
// PendingApproval indicates whether the user requires approval before being activated
|
||||||
@@ -106,6 +108,26 @@ type User struct {
|
|||||||
Email string `gorm:"default:''"`
|
Email string `gorm:"default:''"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *User) LoadAutoGroups() {
|
||||||
|
u.AutoGroups = make([]string, 0, len(u.Groups))
|
||||||
|
for _, group := range u.Groups {
|
||||||
|
u.AutoGroups = append(u.AutoGroups, group.GroupID)
|
||||||
|
}
|
||||||
|
u.Groups = []*GroupUser{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) StoreAutoGroups() {
|
||||||
|
u.Groups = make([]*GroupUser, 0, len(u.AutoGroups))
|
||||||
|
for _, groupID := range u.AutoGroups {
|
||||||
|
u.Groups = append(u.Groups, &GroupUser{
|
||||||
|
AccountID: u.AccountID,
|
||||||
|
GroupID: groupID,
|
||||||
|
UserID: u.Id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
u.AutoGroups = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
// IsBlocked returns true if the user is blocked, false otherwise
|
// IsBlocked returns true if the user is blocked, false otherwise
|
||||||
func (u *User) IsBlocked() bool {
|
func (u *User) IsBlocked() bool {
|
||||||
return u.Blocked
|
return u.Blocked
|
||||||
@@ -198,8 +220,20 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
|||||||
|
|
||||||
// Copy the user
|
// Copy the user
|
||||||
func (u *User) Copy() *User {
|
func (u *User) Copy() *User {
|
||||||
autoGroups := make([]string, len(u.AutoGroups))
|
var groupUsers []*GroupUser
|
||||||
copy(autoGroups, u.AutoGroups)
|
if u.Groups != nil {
|
||||||
|
groupUsers = make([]*GroupUser, len(u.Groups))
|
||||||
|
for i, groupUser := range u.Groups {
|
||||||
|
groupUsers[i] = groupUser.Copy()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var autoGroups []string
|
||||||
|
if u.AutoGroups != nil {
|
||||||
|
autoGroups = make([]string, len(u.AutoGroups))
|
||||||
|
copy(autoGroups, u.AutoGroups)
|
||||||
|
}
|
||||||
|
|
||||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||||
for k, v := range u.PATs {
|
for k, v := range u.PATs {
|
||||||
pats[k] = v.Copy()
|
pats[k] = v.Copy()
|
||||||
@@ -221,6 +255,7 @@ func (u *User) Copy() *User {
|
|||||||
IntegrationReference: u.IntegrationReference,
|
IntegrationReference: u.IntegrationReference,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Name: u.Name,
|
Name: u.Name,
|
||||||
|
Groups: groupUsers,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,21 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
|||||||
newUser.AccountID = accountID
|
newUser.AccountID = accountID
|
||||||
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
||||||
|
|
||||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
if err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||||
|
err = tx.SaveUser(ctx, newUser)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, groupID := range autoGroups {
|
||||||
|
err = tx.AddUserToGroup(ctx, accountID, newUserID, groupID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add user to group %s: %w", groupID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +133,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
Id: idpUser.ID,
|
Id: idpUser.ID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Role: types.StrRoleToUserRole(invite.Role),
|
Role: types.StrRoleToUserRole(invite.Role),
|
||||||
AutoGroups: invite.AutoGroups,
|
|
||||||
Issued: invite.Issued,
|
Issued: invite.Issued,
|
||||||
IntegrationReference: invite.IntegrationReference,
|
IntegrationReference: invite.IntegrationReference,
|
||||||
CreatedAt: time.Now().UTC(),
|
CreatedAt: time.Now().UTC(),
|
||||||
@@ -127,6 +140,23 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
|||||||
Name: invite.Name,
|
Name: invite.Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||||
|
err = tx.SaveUser(ctx, newUser)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, group := range invite.AutoGroups {
|
||||||
|
err = tx.AddUserToGroup(ctx, accountID, userID, group)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add user to group %s: %w", group, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -715,6 +745,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
|||||||
updatedUser.Role = update.Role
|
updatedUser.Role = update.Role
|
||||||
updatedUser.Blocked = update.Blocked
|
updatedUser.Blocked = update.Blocked
|
||||||
updatedUser.AutoGroups = update.AutoGroups
|
updatedUser.AutoGroups = update.AutoGroups
|
||||||
|
updatedUser.StoreAutoGroups()
|
||||||
// these two fields can't be set via API, only via direct call to the method
|
// these two fields can't be set via API, only via direct call to the method
|
||||||
updatedUser.Issued = update.Issued
|
updatedUser.Issued = update.Issued
|
||||||
updatedUser.IntegrationReference = update.IntegrationReference
|
updatedUser.IntegrationReference = update.IntegrationReference
|
||||||
@@ -737,28 +768,58 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
|||||||
peersToExpire = userPeers
|
peersToExpire = userPeers
|
||||||
}
|
}
|
||||||
|
|
||||||
var removedGroups, addedGroups []string
|
updateAccountPeers, removedGroupsIDs, addedGroupsIDs, err := am.processUserGroupsUpdate(ctx, transaction, oldUser, updatedUser, userPeers, settings)
|
||||||
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
if err != nil {
|
||||||
removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
return false, nil, nil, nil, err
|
||||||
addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
|
}
|
||||||
|
|
||||||
|
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction)
|
||||||
|
|
||||||
|
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) processUserGroupsUpdate(ctx context.Context, transaction store.Store, oldUser *types.User, updatedUser *types.User, userPeers []*nbpeer.Peer, settings *types.Settings) (bool, []string, []string, error) {
|
||||||
|
removedGroups := util.Difference(oldUser.AutoGroups, updatedUser.AutoGroups)
|
||||||
|
addedGroups := util.Difference(updatedUser.AutoGroups, oldUser.AutoGroups)
|
||||||
|
|
||||||
|
updateAccountPeers := len(userPeers) > 0
|
||||||
|
|
||||||
|
removedGroupsIDs := make([]string, 0, len(removedGroups))
|
||||||
|
for _, id := range removedGroups {
|
||||||
|
err := transaction.RemoveUserFromGroup(ctx, updatedUser.Id, id)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, nil, fmt.Errorf("failed to remove user %s from group %s: %w", updatedUser.Id, id, err)
|
||||||
|
}
|
||||||
|
updateAccountPeers = true
|
||||||
|
removedGroupsIDs = append(removedGroupsIDs, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
addedGroupsIDs := make([]string, 0, len(addedGroups))
|
||||||
|
for _, id := range addedGroups {
|
||||||
|
err := transaction.AddUserToGroup(ctx, updatedUser.AccountID, updatedUser.Id, id)
|
||||||
|
if err != nil {
|
||||||
|
return false, nil, nil, fmt.Errorf("failed to add user %s to group %s: %w", updatedUser.Id, id, err)
|
||||||
|
}
|
||||||
|
updateAccountPeers = true
|
||||||
|
addedGroupsIDs = append(addedGroupsIDs, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if updatedUser.Groups != nil && settings.GroupsPropagationEnabled {
|
||||||
for _, peer := range userPeers {
|
for _, peer := range userPeers {
|
||||||
for _, groupID := range removedGroups {
|
for _, id := range removedGroups {
|
||||||
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
|
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, id); err != nil {
|
||||||
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
|
return false, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, id, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, groupID := range addedGroups {
|
for _, id := range addedGroups {
|
||||||
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
|
if err := transaction.AddPeerToGroup(ctx, updatedUser.AccountID, peer.ID, id); err != nil {
|
||||||
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
|
return false, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, id, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAccountPeers := len(userPeers) > 0
|
return updateAccountPeers, removedGroupsIDs, addedGroupsIDs, nil
|
||||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction)
|
|
||||||
|
|
||||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
||||||
|
|||||||
@@ -345,6 +345,9 @@ func TestUser_Copy(t *testing.T) {
|
|||||||
IsServiceUser: true,
|
IsServiceUser: true,
|
||||||
ServiceUserName: "servicename",
|
ServiceUserName: "servicename",
|
||||||
AutoGroups: []string{"group1", "group2"},
|
AutoGroups: []string{"group1", "group2"},
|
||||||
|
Groups: []*types.GroupUser{
|
||||||
|
{AccountID: "accountId", GroupID: "groupId", UserID: "userId"},
|
||||||
|
},
|
||||||
PATs: map[string]*types.PersonalAccessToken{
|
PATs: map[string]*types.PersonalAccessToken{
|
||||||
"pat1": {
|
"pat1": {
|
||||||
ID: "pat1",
|
ID: "pat1",
|
||||||
@@ -413,6 +416,14 @@ func TestUser_CreateServiceUser(t *testing.T) {
|
|||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||||
|
account.Groups["group1"] = &types.Group{
|
||||||
|
ID: "group1",
|
||||||
|
Name: "group1",
|
||||||
|
}
|
||||||
|
account.Groups["group2"] = &types.Group{
|
||||||
|
ID: "group2",
|
||||||
|
Name: "group2",
|
||||||
|
}
|
||||||
|
|
||||||
err = store.SaveAccount(context.Background(), account)
|
err = store.SaveAccount(context.Background(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -460,6 +471,14 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) {
|
|||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||||
|
account.Groups["group1"] = &types.Group{
|
||||||
|
ID: "group1",
|
||||||
|
Name: "group1",
|
||||||
|
}
|
||||||
|
account.Groups["group2"] = &types.Group{
|
||||||
|
ID: "group2",
|
||||||
|
Name: "group2",
|
||||||
|
}
|
||||||
|
|
||||||
err = store.SaveAccount(context.Background(), account)
|
err = store.SaveAccount(context.Background(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -539,6 +558,14 @@ func TestUser_InviteNewUser(t *testing.T) {
|
|||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "", "", "", false)
|
||||||
|
account.Groups["group1"] = &types.Group{
|
||||||
|
ID: "group1",
|
||||||
|
Name: "group1",
|
||||||
|
}
|
||||||
|
account.Groups["group2"] = &types.Group{
|
||||||
|
ID: "group2",
|
||||||
|
Name: "group2",
|
||||||
|
}
|
||||||
|
|
||||||
err = store.SaveAccount(context.Background(), account)
|
err = store.SaveAccount(context.Background(), account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1653,6 +1680,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
LastLogin: time.Time{},
|
LastLogin: time.Time{},
|
||||||
Issued: "api",
|
Issued: "api",
|
||||||
IntegrationReference: integration_reference.IntegrationReference{},
|
IntegrationReference: integration_reference.IntegrationReference{},
|
||||||
|
AutoGroups: []string{},
|
||||||
},
|
},
|
||||||
Permissions: mergeRolePermissions(roles.User),
|
Permissions: mergeRolePermissions(roles.User),
|
||||||
},
|
},
|
||||||
@@ -1672,6 +1700,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
LastLogin: time.Time{},
|
LastLogin: time.Time{},
|
||||||
Issued: "api",
|
Issued: "api",
|
||||||
IntegrationReference: integration_reference.IntegrationReference{},
|
IntegrationReference: integration_reference.IntegrationReference{},
|
||||||
|
AutoGroups: []string{},
|
||||||
},
|
},
|
||||||
Permissions: mergeRolePermissions(roles.Admin),
|
Permissions: mergeRolePermissions(roles.Admin),
|
||||||
},
|
},
|
||||||
@@ -1691,6 +1720,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
LastLogin: time.Time{},
|
LastLogin: time.Time{},
|
||||||
Issued: "api",
|
Issued: "api",
|
||||||
IntegrationReference: integration_reference.IntegrationReference{},
|
IntegrationReference: integration_reference.IntegrationReference{},
|
||||||
|
AutoGroups: []string{},
|
||||||
},
|
},
|
||||||
Permissions: mergeRolePermissions(roles.User),
|
Permissions: mergeRolePermissions(roles.User),
|
||||||
Restricted: true,
|
Restricted: true,
|
||||||
@@ -1712,6 +1742,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
|||||||
LastLogin: time.Time{},
|
LastLogin: time.Time{},
|
||||||
Issued: "api",
|
Issued: "api",
|
||||||
IntegrationReference: integration_reference.IntegrationReference{},
|
IntegrationReference: integration_reference.IntegrationReference{},
|
||||||
|
AutoGroups: []string{},
|
||||||
},
|
},
|
||||||
Permissions: mergeRolePermissions(roles.User),
|
Permissions: mergeRolePermissions(roles.User),
|
||||||
Restricted: false,
|
Restricted: false,
|
||||||
|
|||||||
Reference in New Issue
Block a user