mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 19:56:46 +00:00
migrate auto groups to different table
This commit is contained in:
@@ -1403,9 +1403,20 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups)
|
||||
for _, group := range addNewGroups {
|
||||
err = transaction.AddUserToGroup(ctx, userAuth.AccountId, group, user.Id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error adding user to group: %w", err)
|
||||
}
|
||||
}
|
||||
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
|
||||
for _, group := range removeOldGroups {
|
||||
err = transaction.RemoveUserFromGroup(ctx, user.Id, group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error removing user from group: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
user.AutoGroups = updatedAutoGroups
|
||||
if err = transaction.SaveUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("error saving user: %w", err)
|
||||
}
|
||||
|
||||
@@ -487,3 +487,103 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
|
||||
log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupOrphanedIDs removes non-existent IDs from the JSON array column.
|
||||
// T is the type of the model that contains the list.
|
||||
// This migration cleans up the lists field by removing IDs that no longer exist in the target table.
|
||||
func CleanupOrphanedIDs[T, S any](ctx context.Context, db *gorm.DB, columnName string) error {
|
||||
var sourceModel T
|
||||
var fkModel S
|
||||
|
||||
if !db.Migrator().HasTable(&sourceModel) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", sourceModel)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasTable(&fkModel) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", fkModel)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmt := &gorm.Statement{DB: db}
|
||||
err := stmt.Parse(&sourceModel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse model: %w", err)
|
||||
}
|
||||
tableName := stmt.Schema.Table
|
||||
|
||||
if !db.Migrator().HasColumn(&sourceModel, columnName) {
|
||||
log.WithContext(ctx).Debugf("Column %s does not exist in table %s, no migration needed", columnName, tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := db.Transaction(func(tx *gorm.DB) error {
|
||||
var rows []map[string]any
|
||||
if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil {
|
||||
return fmt.Errorf("find rows: %w", err)
|
||||
}
|
||||
|
||||
// Get all valid IDs from the fk table
|
||||
var validIDs []string
|
||||
if err := tx.Model(fkModel).Select("id").Pluck("id", &validIDs).Error; err != nil {
|
||||
return fmt.Errorf("fetch valid group IDs: %w", err)
|
||||
}
|
||||
|
||||
validIDMap := make(map[string]bool, len(validIDs))
|
||||
for _, id := range validIDs {
|
||||
validIDMap[id] = true
|
||||
}
|
||||
|
||||
updatedCount := 0
|
||||
for _, row := range rows {
|
||||
jsonValue, ok := row[columnName].(string)
|
||||
if !ok || jsonValue == "" || jsonValue == "null" {
|
||||
continue
|
||||
}
|
||||
|
||||
var list []string
|
||||
if err := json.Unmarshal([]byte(jsonValue), &list); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to unmarshal %s for id %v: %v", columnName, row["id"], err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(list) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter out non-existent IDs
|
||||
cleanedList := make([]string, 0, len(list))
|
||||
for _, groupID := range list {
|
||||
if validIDMap[groupID] {
|
||||
cleanedList = append(cleanedList, groupID)
|
||||
}
|
||||
}
|
||||
|
||||
// Only update if there were orphaned ids removed
|
||||
if len(cleanedList) != len(list) {
|
||||
cleanedJSON, err := json.Marshal(cleanedList)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal cleaned %s: %w", columnName, err)
|
||||
}
|
||||
|
||||
if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, cleanedJSON).Error; err != nil {
|
||||
return fmt.Errorf("update row with id %v: %w", row["id"], err)
|
||||
}
|
||||
updatedCount++
|
||||
}
|
||||
}
|
||||
|
||||
if updatedCount > 0 {
|
||||
log.WithContext(ctx).Infof("Cleaned up orphaned %s in %d rows from table %s", columnName, updatedCount, tableName)
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("No orphaned %s found in table %s", columnName, tableName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Cleanup of orphaned auto_groups from table %s completed", tableName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
@@ -185,6 +185,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
group.StoreGroupUsers()
|
||||
}
|
||||
|
||||
err := s.transaction(func(tx *gorm.DB) error {
|
||||
@@ -243,6 +244,7 @@ func generateAccountSQLTypes(account *types.Account) {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
account.UsersG = append(account.UsersG, *user)
|
||||
}
|
||||
|
||||
@@ -453,6 +455,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
||||
userCopy := user.Copy()
|
||||
userCopy.Email = user.Email
|
||||
userCopy.Name = user.Name
|
||||
userCopy.StoreAutoGroups()
|
||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
}
|
||||
@@ -472,6 +475,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
userCopy := user.Copy()
|
||||
userCopy.Email = user.Email
|
||||
userCopy.Name = user.Name
|
||||
userCopy.StoreAutoGroups()
|
||||
|
||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
@@ -617,6 +621,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
var user types.User
|
||||
result := tx.
|
||||
Preload("Groups").
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
Where("personal_access_tokens.id = ?", patID).Take(&user)
|
||||
if result.Error != nil {
|
||||
@@ -631,6 +636,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
user.LoadAutoGroups()
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -641,7 +648,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result := tx.Take(&user, idQueryCondition, userID)
|
||||
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@@ -653,6 +660,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
user.LoadAutoGroups()
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -680,7 +689,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
var users []*types.User
|
||||
result := tx.Find(&users, accountIDCondition, accountID)
|
||||
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -693,6 +702,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
}
|
||||
|
||||
return users, nil
|
||||
@@ -705,7 +715,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
|
||||
@@ -717,6 +727,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
user.LoadAutoGroups()
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -738,6 +750,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
g.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -767,6 +780,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
g.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -867,6 +881,8 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("SetupKeysG").
|
||||
Preload("PeersG").
|
||||
Preload("UsersG").
|
||||
Preload("UsersG.GroupUser").
|
||||
Preload("GroupsG").
|
||||
Preload("GroupsG.GroupPeers").
|
||||
Preload("RoutesG").
|
||||
Preload("NameServerGroupsG").
|
||||
@@ -908,9 +924,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
pat.UserID = ""
|
||||
user.PATs[pat.ID] = &pat
|
||||
}
|
||||
if user.AutoGroups == nil {
|
||||
user.AutoGroups = []string{}
|
||||
if user.Groups == nil {
|
||||
user.Groups = []*types.GroupUser{}
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
@@ -1116,8 +1133,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
errChan = make(chan error, 3)
|
||||
wg.Add(4)
|
||||
errChan = make(chan error, 4)
|
||||
|
||||
var pats []types.PersonalAccessToken
|
||||
go func() {
|
||||
@@ -1149,6 +1166,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
}
|
||||
}()
|
||||
|
||||
var groupUsers []types.GroupUser
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var err error
|
||||
groupUsers, err = s.getGroupUsers(ctx, userIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for e := range errChan {
|
||||
@@ -1174,6 +1201,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
||||
}
|
||||
|
||||
groupsByUserID := make(map[string][]*types.GroupUser)
|
||||
for i := range groupUsers {
|
||||
gu := &groupUsers[i]
|
||||
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for i := range account.SetupKeysG {
|
||||
key := &account.SetupKeysG[i]
|
||||
@@ -1199,6 +1232,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
user.PATs[pat.ID] = pat
|
||||
}
|
||||
}
|
||||
user.Groups = groupsByUserID[user.Id]
|
||||
user.LoadAutoGroups()
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
@@ -1596,43 +1631,40 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
}
|
||||
|
||||
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
||||
var u types.User
|
||||
var autoGroups []byte
|
||||
var lastLogin, createdAt sql.NullTime
|
||||
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = &lastLogin.Time
|
||||
}
|
||||
if createdAt.Valid {
|
||||
u.CreatedAt = createdAt.Time
|
||||
}
|
||||
if isServiceUser.Valid {
|
||||
u.IsServiceUser = isServiceUser.Bool
|
||||
}
|
||||
if nonDeletable.Valid {
|
||||
u.NonDeletable = nonDeletable.Bool
|
||||
}
|
||||
if blocked.Valid {
|
||||
u.Blocked = blocked.Bool
|
||||
}
|
||||
if pendingApproval.Valid {
|
||||
u.PendingApproval = pendingApproval.Bool
|
||||
}
|
||||
if autoGroups != nil {
|
||||
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
|
||||
} else {
|
||||
u.AutoGroups = []string{}
|
||||
}
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||
if err != nil {
|
||||
return u, err
|
||||
}
|
||||
return u, err
|
||||
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = &lastLogin.Time
|
||||
}
|
||||
if createdAt.Valid {
|
||||
u.CreatedAt = createdAt.Time
|
||||
}
|
||||
if isServiceUser.Valid {
|
||||
u.IsServiceUser = isServiceUser.Bool
|
||||
}
|
||||
if nonDeletable.Valid {
|
||||
u.NonDeletable = nonDeletable.Bool
|
||||
}
|
||||
if blocked.Valid {
|
||||
u.Blocked = blocked.Bool
|
||||
}
|
||||
if pendingApproval.Valid {
|
||||
u.PendingApproval = pendingApproval.Bool
|
||||
}
|
||||
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2038,6 +2070,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
|
||||
return groupPeers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return groupUsers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
|
||||
var user types.User
|
||||
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
|
||||
@@ -2659,6 +2707,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
user := &types.GroupUser{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(user).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
|
||||
return status.Errorf(status.Internal, "failed to add user to group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
|
||||
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
|
||||
return status.Errorf(status.Internal, "failed to remove user from group")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.
|
||||
@@ -2745,6 +2828,7 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -3146,6 +3230,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
|
||||
return group, nil
|
||||
}
|
||||
@@ -3177,6 +3262,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
@@ -3198,6 +3284,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
|
||||
@@ -89,6 +89,8 @@ type Store interface {
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error
|
||||
RemoveUserFromGroup(ctx context.Context, userID, groupID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
|
||||
@@ -350,6 +352,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[types.User](ctx, db, "email", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups")
|
||||
},
|
||||
}
|
||||
} // migratePostAuto migrates the SQLite database to the latest schema
|
||||
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
|
||||
@@ -381,6 +386,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
|
||||
}
|
||||
})
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any {
|
||||
return &types.GroupUser{
|
||||
AccountID: accountID,
|
||||
GroupID: value,
|
||||
UserID: id,
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ type Group struct {
|
||||
// Peers list of the group
|
||||
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
|
||||
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
Users []string `gorm:"-"`
|
||||
GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// Resources contains a list of resources in that group
|
||||
Resources []Resource `gorm:"serializer:json"`
|
||||
@@ -41,6 +43,12 @@ type GroupPeer struct {
|
||||
PeerID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
type GroupUser struct {
|
||||
AccountID string `gorm:"index"`
|
||||
GroupID string `gorm:"primaryKey"`
|
||||
UserID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupPeers() {
|
||||
g.Peers = make([]string, len(g.GroupPeers))
|
||||
for i, peer := range g.GroupPeers {
|
||||
@@ -61,6 +69,26 @@ func (g *Group) StoreGroupPeers() {
|
||||
g.Peers = []string{}
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupUsers() {
|
||||
g.Users = make([]string, len(g.GroupUsers))
|
||||
for i, user := range g.GroupUsers {
|
||||
g.Users[i] = user.UserID
|
||||
}
|
||||
g.GroupUsers = []GroupUser{}
|
||||
}
|
||||
|
||||
func (g *Group) StoreGroupUsers() {
|
||||
g.GroupUsers = make([]GroupUser, len(g.Users))
|
||||
for i, user := range g.Users {
|
||||
g.GroupUsers[i] = GroupUser{
|
||||
AccountID: g.AccountID,
|
||||
GroupID: g.ID,
|
||||
UserID: user,
|
||||
}
|
||||
}
|
||||
g.Users = []string{}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the group
|
||||
func (g *Group) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
@@ -78,11 +106,13 @@ func (g *Group) Copy() *Group {
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
||||
GroupUsers: make([]GroupUser, len(g.GroupUsers)),
|
||||
Resources: make([]Resource, len(g.Resources)),
|
||||
IntegrationReference: g.IntegrationReference,
|
||||
}
|
||||
copy(group.Peers, g.Peers)
|
||||
copy(group.GroupPeers, g.GroupPeers)
|
||||
copy(group.GroupUsers, g.GroupUsers)
|
||||
copy(group.Resources, g.Resources)
|
||||
return group
|
||||
}
|
||||
|
||||
@@ -85,9 +85,11 @@ type User struct {
|
||||
// ServiceUserName is only set if IsServiceUser is true
|
||||
ServiceUserName string
|
||||
// AutoGroups is a list of Group IDs to auto-assign to peers registered by this user
|
||||
AutoGroups []string `gorm:"serializer:json"`
|
||||
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
AutoGroups []string `gorm:"-"`
|
||||
// GroupUsers replaces old AutoGroups
|
||||
Groups []*GroupUser `gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
PATs map[string]*PersonalAccessToken `gorm:"-"`
|
||||
PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
// Blocked indicates whether the user is blocked. Blocked users can't use the system.
|
||||
Blocked bool
|
||||
// PendingApproval indicates whether the user requires approval before being activated
|
||||
@@ -106,6 +108,24 @@ type User struct {
|
||||
Email string `gorm:"default:''"`
|
||||
}
|
||||
|
||||
func (u *User) LoadAutoGroups() {
|
||||
u.AutoGroups = make([]string, 0, len(u.Groups))
|
||||
for _, group := range u.Groups {
|
||||
u.AutoGroups = append(u.AutoGroups, group.GroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) StoreAutoGroups() {
|
||||
u.Groups = make([]*GroupUser, 0, len(u.Groups))
|
||||
for _, groupID := range u.AutoGroups {
|
||||
u.Groups = append(u.Groups, &GroupUser{
|
||||
AccountID: u.AccountID,
|
||||
GroupID: groupID,
|
||||
UserID: u.Id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
func (u *User) IsBlocked() bool {
|
||||
return u.Blocked
|
||||
@@ -198,8 +218,11 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
|
||||
// Copy the user
|
||||
func (u *User) Copy() *User {
|
||||
groupUsers := make([]*GroupUser, len(u.Groups))
|
||||
copy(groupUsers, u.Groups)
|
||||
autoGroups := make([]string, len(u.AutoGroups))
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
|
||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||
for k, v := range u.PATs {
|
||||
pats[k] = v.Copy()
|
||||
@@ -221,6 +244,7 @@ func (u *User) Copy() *User {
|
||||
IntegrationReference: u.IntegrationReference,
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
Groups: groupUsers,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,21 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
||||
newUser.AccountID = accountID
|
||||
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
||||
|
||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||
if err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||
err = tx.SaveUser(ctx, newUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range autoGroups {
|
||||
err = tx.AddUserToGroup(ctx, accountID, newUserID, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to group %s: %w", groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -119,7 +133,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
Id: idpUser.ID,
|
||||
AccountID: accountID,
|
||||
Role: types.StrRoleToUserRole(invite.Role),
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Issued: invite.Issued,
|
||||
IntegrationReference: invite.IntegrationReference,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
@@ -127,6 +140,23 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
Name: invite.Name,
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||
err = tx.SaveUser(ctx, newUser)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, group := range invite.AutoGroups {
|
||||
err = tx.AddUserToGroup(ctx, accountID, userID, group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to group %s: %w", group, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save user: %w", err)
|
||||
}
|
||||
|
||||
if err = am.Store.SaveUser(ctx, newUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -737,28 +767,63 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
peersToExpire = userPeers
|
||||
}
|
||||
|
||||
var removedGroups, addedGroups []string
|
||||
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
||||
removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
|
||||
updateAccountPeers, removedGroupsIDs, addedGroupsIDs, err := am.processUserGroupsUpdate(ctx, transaction, oldUser, updatedUser, userPeers, settings)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
|
||||
updatedUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, updatedUser.Id)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction)
|
||||
|
||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) processUserGroupsUpdate(ctx context.Context, transaction store.Store, oldUser *types.User, updatedUser *types.User, userPeers []*nbpeer.Peer, settings *types.Settings) (bool, []string, []string, error) {
|
||||
removedGroups := util.Difference(oldUser.AutoGroups, updatedUser.AutoGroups)
|
||||
addedGroups := util.Difference(updatedUser.AutoGroups, oldUser.AutoGroups)
|
||||
|
||||
updateAccountPeers := len(userPeers) > 0
|
||||
|
||||
removedGroupsIDs := make([]string, 0, len(removedGroups))
|
||||
for _, id := range removedGroups {
|
||||
err := transaction.RemoveUserFromGroup(ctx, updatedUser.Id, id)
|
||||
if err != nil {
|
||||
return false, nil, nil, fmt.Errorf("failed to remove user %s from group %s: %w", updatedUser.Id, id, err)
|
||||
}
|
||||
updateAccountPeers = true
|
||||
removedGroupsIDs = append(removedGroupsIDs, id)
|
||||
}
|
||||
|
||||
addedGroupsIDs := make([]string, 0, len(addedGroups))
|
||||
for _, id := range addedGroups {
|
||||
err := transaction.AddUserToGroup(ctx, updatedUser.AccountID, updatedUser.Id, id)
|
||||
if err != nil {
|
||||
return false, nil, nil, fmt.Errorf("failed to add user %s to group %s: %w", updatedUser.Id, id, err)
|
||||
}
|
||||
updateAccountPeers = true
|
||||
addedGroupsIDs = append(addedGroupsIDs, id)
|
||||
}
|
||||
|
||||
if updatedUser.Groups != nil && settings.GroupsPropagationEnabled {
|
||||
for _, peer := range userPeers {
|
||||
for _, groupID := range removedGroups {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err)
|
||||
for _, id := range removedGroups {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, id); err != nil {
|
||||
return false, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, id, err)
|
||||
}
|
||||
}
|
||||
for _, groupID := range addedGroups {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err)
|
||||
for _, id := range addedGroups {
|
||||
if err := transaction.AddPeerToGroup(ctx, updatedUser.AccountID, peer.ID, id); err != nil {
|
||||
return false, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, id, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers := len(userPeers) > 0
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction)
|
||||
|
||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||
return updateAccountPeers, removedGroupsIDs, addedGroupsIDs, nil
|
||||
}
|
||||
|
||||
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
|
||||
|
||||
Reference in New Issue
Block a user