migrate auto groups to different table

This commit is contained in:
pascal
2026-01-08 15:45:48 +01:00
parent fb71b0d04b
commit f7ee019f26
7 changed files with 387 additions and 56 deletions

View File

@@ -1403,9 +1403,20 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups)
for _, group := range addNewGroups {
err = transaction.AddUserToGroup(ctx, userAuth.AccountId, group, user.Id)
if err != nil {
return fmt.Errorf("error adding user to group: %w", err)
}
}
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
for _, group := range removeOldGroups {
err = transaction.RemoveUserFromGroup(ctx, user.Id, group)
if err != nil {
return fmt.Errorf("error removing user from group: %w", err)
}
}
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
}

View File

@@ -487,3 +487,103 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
return nil
}
// CleanupOrphanedIDs removes non-existent IDs from the JSON array column.
// T is the type of the model that contains the list.
// This migration cleans up the lists field by removing IDs that no longer exist in the target table.
func CleanupOrphanedIDs[T, S any](ctx context.Context, db *gorm.DB, columnName string) error {
var sourceModel T
var fkModel S
if !db.Migrator().HasTable(&sourceModel) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", sourceModel)
return nil
}
if !db.Migrator().HasTable(&fkModel) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", fkModel)
return nil
}
stmt := &gorm.Statement{DB: db}
err := stmt.Parse(&sourceModel)
if err != nil {
return fmt.Errorf("parse model: %w", err)
}
tableName := stmt.Schema.Table
if !db.Migrator().HasColumn(&sourceModel, columnName) {
log.WithContext(ctx).Debugf("Column %s does not exist in table %s, no migration needed", columnName, tableName)
return nil
}
if err := db.Transaction(func(tx *gorm.DB) error {
var rows []map[string]any
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
return fmt.Errorf("find rows: %w", err)
}
// Get all valid IDs from the fk table
var validIDs []string
if err := tx.Model(fkModel).Select("id").Pluck("id", &validIDs).Error; err != nil {
return fmt.Errorf("fetch valid group IDs: %w", err)
}
validIDMap := make(map[string]bool, len(validIDs))
for _, id := range validIDs {
validIDMap[id] = true
}
updatedCount := 0
for _, row := range rows {
jsonValue, ok := row[columnName].(string)
if !ok || jsonValue == "" || jsonValue == "null" {
continue
}
var list []string
if err := json.Unmarshal([]byte(jsonValue), &list); err != nil {
log.WithContext(ctx).Warnf("Failed to unmarshal %s for id %v: %v", columnName, row["id"], err)
continue
}
if len(list) == 0 {
continue
}
// Filter out non-existent IDs
cleanedList := make([]string, 0, len(list))
for _, groupID := range list {
if validIDMap[groupID] {
cleanedList = append(cleanedList, groupID)
}
}
// Only update if there were orphaned ids removed
if len(cleanedList) != len(list) {
cleanedJSON, err := json.Marshal(cleanedList)
if err != nil {
return fmt.Errorf("marshal cleaned %s: %w", columnName, err)
}
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, cleanedJSON).Error; err != nil {
return fmt.Errorf("update row with id %v: %w", row["id"], err)
}
updatedCount++
}
}
if updatedCount > 0 {
log.WithContext(ctx).Infof("Cleaned up orphaned %s in %d rows from table %s", columnName, updatedCount, tableName)
} else {
log.WithContext(ctx).Debugf("No orphaned %s found in table %s", columnName, tableName)
}
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Infof("Cleanup of orphaned auto_groups from table %s completed", tableName)
return nil
}

View File

@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -185,6 +185,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
for _, group := range account.GroupsG {
group.StoreGroupPeers()
group.StoreGroupUsers()
}
err := s.transaction(func(tx *gorm.DB) error {
@@ -243,6 +244,7 @@ func generateAccountSQLTypes(account *types.Account) {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
user.LoadAutoGroups()
account.UsersG = append(account.UsersG, *user)
}
@@ -453,6 +455,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
@@ -472,6 +475,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
@@ -617,6 +621,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
var user types.User
result := tx.
Preload("Groups").
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).Take(&user)
if result.Error != nil {
@@ -631,6 +636,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -641,7 +648,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, idQueryCondition, userID)
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -653,6 +660,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -680,7 +689,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
}
var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID)
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -693,6 +702,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
}
return users, nil
@@ -705,7 +715,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
@@ -717,6 +727,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -738,6 +750,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -767,6 +780,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -867,6 +881,8 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("UsersG.GroupUser").
Preload("GroupsG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
@@ -908,9 +924,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.AutoGroups = []string{}
if user.Groups == nil {
user.Groups = []*types.GroupUser{}
}
user.LoadAutoGroups()
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
@@ -1116,8 +1133,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
wg.Add(4)
errChan = make(chan error, 4)
var pats []types.PersonalAccessToken
go func() {
@@ -1149,6 +1166,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
}()
var groupUsers []types.GroupUser
go func() {
defer wg.Done()
var err error
groupUsers, err = s.getGroupUsers(ctx, userIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
@@ -1174,6 +1201,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
groupsByUserID := make(map[string][]*types.GroupUser)
for i := range groupUsers {
gu := &groupUsers[i]
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
@@ -1199,6 +1232,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
user.PATs[pat.ID] = pat
}
}
user.Groups = groupsByUserID[user.Id]
user.LoadAutoGroups()
account.Users[user.Id] = user
}
@@ -1596,43 +1631,40 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
}
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
var u types.User
var autoGroups []byte
var lastLogin, createdAt sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err == nil {
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
if autoGroups != nil {
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
} else {
u.AutoGroups = []string{}
}
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err != nil {
return u, err
}
return u, err
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
return u, nil
})
if err != nil {
return nil, err
@@ -2038,6 +2070,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
return groupPeers, nil
}
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
if len(userIDs) == 0 {
return nil, nil
}
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, userIDs)
if err != nil {
return nil, err
}
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
if err != nil {
return nil, err
}
return groupUsers, nil
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
@@ -2659,6 +2707,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
return nil
}
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
user := &types.GroupUser{
AccountID: accountID,
GroupID: groupID,
UserID: userID,
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
DoNothing: true,
}).Create(user).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
return status.Errorf(status.Internal, "failed to add user to group")
}
return nil
}
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
return status.Errorf(status.Internal, "failed to remove user from group")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.
@@ -2745,6 +2828,7 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
}
return groups, nil
@@ -3146,6 +3230,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return group, nil
}
@@ -3177,6 +3262,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return &group, nil
}
@@ -3198,6 +3284,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
groupsMap[group.ID] = group
}

View File

@@ -89,6 +89,8 @@ type Store interface {
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error
RemoveUserFromGroup(ctx context.Context, userID, groupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
@@ -350,6 +352,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.MigrateNewField[types.User](ctx, db, "email", "")
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups")
},
}
} // migratePostAuto migrates the SQLite database to the latest schema
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
@@ -381,6 +386,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
}
})
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any {
return &types.GroupUser{
AccountID: accountID,
GroupID: value,
UserID: id,
}
})
},
}
}

View File

@@ -28,6 +28,8 @@ type Group struct {
// Peers list of the group
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
Users []string `gorm:"-"`
GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
// Resources contains a list of resources in that group
Resources []Resource `gorm:"serializer:json"`
@@ -41,6 +43,12 @@ type GroupPeer struct {
PeerID string `gorm:"primaryKey"`
}
type GroupUser struct {
AccountID string `gorm:"index"`
GroupID string `gorm:"primaryKey"`
UserID string `gorm:"primaryKey"`
}
func (g *Group) LoadGroupPeers() {
g.Peers = make([]string, len(g.GroupPeers))
for i, peer := range g.GroupPeers {
@@ -61,6 +69,26 @@ func (g *Group) StoreGroupPeers() {
g.Peers = []string{}
}
func (g *Group) LoadGroupUsers() {
g.Users = make([]string, len(g.GroupUsers))
for i, user := range g.GroupUsers {
g.Users[i] = user.UserID
}
g.GroupUsers = []GroupUser{}
}
func (g *Group) StoreGroupUsers() {
g.GroupUsers = make([]GroupUser, len(g.Users))
for i, user := range g.Users {
g.GroupUsers[i] = GroupUser{
AccountID: g.AccountID,
GroupID: g.ID,
UserID: user,
}
}
g.Users = []string{}
}
// EventMeta returns activity event meta related to the group
func (g *Group) EventMeta() map[string]any {
return map[string]any{"name": g.Name}
@@ -78,11 +106,13 @@ func (g *Group) Copy() *Group {
Issued: g.Issued,
Peers: make([]string, len(g.Peers)),
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
GroupUsers: make([]GroupUser, len(g.GroupUsers)),
Resources: make([]Resource, len(g.Resources)),
IntegrationReference: g.IntegrationReference,
}
copy(group.Peers, g.Peers)
copy(group.GroupPeers, g.GroupPeers)
copy(group.GroupUsers, g.GroupUsers)
copy(group.Resources, g.Resources)
return group
}

View File

@@ -85,9 +85,11 @@ type User struct {
// ServiceUserName is only set if IsServiceUser is true
ServiceUserName string
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
AutoGroups []string `gorm:"serializer:json"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
AutoGroups []string `gorm:"-"`
// GroupUsers replaces old AutoGroups
Groups []*GroupUser `gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
PATs map[string]*PersonalAccessToken `gorm:"-"`
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool
// PendingApproval indicates whether the user requires approval before being activated
@@ -106,6 +108,24 @@ type User struct {
Email string `gorm:"default:''"`
}
func (u *User) LoadAutoGroups() {
u.AutoGroups = make([]string, 0, len(u.Groups))
for _, group := range u.Groups {
u.AutoGroups = append(u.AutoGroups, group.GroupID)
}
}
func (u *User) StoreAutoGroups() {
u.Groups = make([]*GroupUser, 0, len(u.Groups))
for _, groupID := range u.AutoGroups {
u.Groups = append(u.Groups, &GroupUser{
AccountID: u.AccountID,
GroupID: groupID,
UserID: u.Id,
})
}
}
// IsBlocked returns true if the user is blocked, false otherwise
func (u *User) IsBlocked() bool {
return u.Blocked
@@ -198,8 +218,11 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
// Copy the user
func (u *User) Copy() *User {
groupUsers := make([]*GroupUser, len(u.Groups))
copy(groupUsers, u.Groups)
autoGroups := make([]string, len(u.AutoGroups))
copy(autoGroups, u.AutoGroups)
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
for k, v := range u.PATs {
pats[k] = v.Copy()
@@ -221,6 +244,7 @@ func (u *User) Copy() *User {
IntegrationReference: u.IntegrationReference,
Email: u.Email,
Name: u.Name,
Groups: groupUsers,
}
}

View File

@@ -45,7 +45,21 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
newUser.AccountID = accountID
log.WithContext(ctx).Debugf("New User: %v", newUser)
if err = am.Store.SaveUser(ctx, newUser); err != nil {
if err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
err = tx.SaveUser(ctx, newUser)
if err != nil {
return err
}
for _, groupID := range autoGroups {
err = tx.AddUserToGroup(ctx, accountID, newUserID, groupID)
if err != nil {
return fmt.Errorf("failed to add user to group %s: %w", groupID, err)
}
}
return nil
}); err != nil {
return nil, err
}
@@ -119,7 +133,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
Id: idpUser.ID,
AccountID: accountID,
Role: types.StrRoleToUserRole(invite.Role),
AutoGroups: invite.AutoGroups,
Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference,
CreatedAt: time.Now().UTC(),
@@ -127,6 +140,23 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
Name: invite.Name,
}
err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
err = tx.SaveUser(ctx, newUser)
if err != nil {
return err
}
for _, group := range invite.AutoGroups {
err = tx.AddUserToGroup(ctx, accountID, userID, group)
if err != nil {
return fmt.Errorf("failed to add user to group %s: %w", group, err)
}
}
return nil
})
if err != nil {
return nil, fmt.Errorf("failed to save user: %w", err)
}
if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err
}
@@ -737,28 +767,63 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
peersToExpire = userPeers
}
var removedGroups, addedGroups []string
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
updateAccountPeers, removedGroupsIDs, addedGroupsIDs, err := am.processUserGroupsUpdate(ctx, transaction, oldUser, updatedUser, userPeers, settings)
if err != nil {
return false, nil, nil, nil, err
}
updatedUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, updatedUser.Id)
if err != nil {
return false, nil, nil, nil, err
}
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
func (am *DefaultAccountManager) processUserGroupsUpdate(ctx context.Context, transaction store.Store, oldUser *types.User, updatedUser *types.User, userPeers []*nbpeer.Peer, settings *types.Settings) (bool, []string, []string, error) {
removedGroups := util.Difference(oldUser.AutoGroups, updatedUser.AutoGroups)
addedGroups := util.Difference(updatedUser.AutoGroups, oldUser.AutoGroups)
updateAccountPeers := len(userPeers) > 0
removedGroupsIDs := make([]string, 0, len(removedGroups))
for _, id := range removedGroups {
err := transaction.RemoveUserFromGroup(ctx, updatedUser.Id, id)
if err != nil {
return false, nil, nil, fmt.Errorf("failed to remove user %s from group %s: %w", updatedUser.Id, id, err)
}
updateAccountPeers = true
removedGroupsIDs = append(removedGroupsIDs, id)
}
addedGroupsIDs := make([]string, 0, len(addedGroups))
for _, id := range addedGroups {
err := transaction.AddUserToGroup(ctx, updatedUser.AccountID, updatedUser.Id, id)
if err != nil {
return false, nil, nil, fmt.Errorf("failed to add user %s to group %s: %w", updatedUser.Id, id, err)
}
updateAccountPeers = true
addedGroupsIDs = append(addedGroupsIDs, id)
}
if updatedUser.Groups != nil && settings.GroupsPropagationEnabled {
for _, peer := range userPeers {
for _, groupID := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
for _, id := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, id); err != nil {
return false, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, id, err)
}
}
for _, groupID := range addedGroups {
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
for _, id := range addedGroups {
if err := transaction.AddPeerToGroup(ctx, updatedUser.AccountID, peer.ID, id); err != nil {
return false, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, id, err)
}
}
}
}
updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
return updateAccountPeers, removedGroupsIDs, addedGroupsIDs, nil
}
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.