migrate auto groups to different table

This commit is contained in:
pascal
2026-01-08 15:45:48 +01:00
parent fb71b0d04b
commit f7ee019f26
7 changed files with 387 additions and 56 deletions

View File

@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -185,6 +185,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
for _, group := range account.GroupsG {
group.StoreGroupPeers()
group.StoreGroupUsers()
}
err := s.transaction(func(tx *gorm.DB) error {
@@ -243,6 +244,7 @@ func generateAccountSQLTypes(account *types.Account) {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
user.LoadAutoGroups()
account.UsersG = append(account.UsersG, *user)
}
@@ -453,6 +455,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
@@ -472,6 +475,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
@@ -617,6 +621,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
var user types.User
result := tx.
Preload("Groups").
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).Take(&user)
if result.Error != nil {
@@ -631,6 +636,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -641,7 +648,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, idQueryCondition, userID)
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -653,6 +660,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -680,7 +689,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
}
var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID)
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -693,6 +702,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
}
return users, nil
@@ -705,7 +715,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
@@ -717,6 +727,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -738,6 +750,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -767,6 +780,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -867,6 +881,8 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("UsersG.GroupUser").
Preload("GroupsG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
@@ -908,9 +924,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.AutoGroups = []string{}
if user.Groups == nil {
user.Groups = []*types.GroupUser{}
}
user.LoadAutoGroups()
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
@@ -1116,8 +1133,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
wg.Add(4)
errChan = make(chan error, 4)
var pats []types.PersonalAccessToken
go func() {
@@ -1149,6 +1166,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
}()
var groupUsers []types.GroupUser
go func() {
defer wg.Done()
var err error
groupUsers, err = s.getGroupUsers(ctx, userIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
@@ -1174,6 +1201,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
groupsByUserID := make(map[string][]*types.GroupUser)
for i := range groupUsers {
gu := &groupUsers[i]
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
@@ -1199,6 +1232,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
user.PATs[pat.ID] = pat
}
}
user.Groups = groupsByUserID[user.Id]
user.LoadAutoGroups()
account.Users[user.Id] = user
}
@@ -1596,43 +1631,40 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
}
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
var u types.User
var autoGroups []byte
var lastLogin, createdAt sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err == nil {
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
if autoGroups != nil {
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
} else {
u.AutoGroups = []string{}
}
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err != nil {
return u, err
}
return u, err
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
return u, nil
})
if err != nil {
return nil, err
@@ -2038,6 +2070,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
return groupPeers, nil
}
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
if len(userIDs) == 0 {
return nil, nil
}
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, userIDs)
if err != nil {
return nil, err
}
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
if err != nil {
return nil, err
}
return groupUsers, nil
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
@@ -2659,6 +2707,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
return nil
}
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
user := &types.GroupUser{
AccountID: accountID,
GroupID: groupID,
UserID: userID,
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
DoNothing: true,
}).Create(user).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
return status.Errorf(status.Internal, "failed to add user to group")
}
return nil
}
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
return status.Errorf(status.Internal, "failed to remove user from group")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.
@@ -2745,6 +2828,7 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
}
return groups, nil
@@ -3146,6 +3230,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return group, nil
}
@@ -3177,6 +3262,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return &group, nil
}
@@ -3198,6 +3284,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
groupsMap[group.ID] = group
}

View File

@@ -89,6 +89,8 @@ type Store interface {
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error
RemoveUserFromGroup(ctx context.Context, userID, groupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
@@ -350,6 +352,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.MigrateNewField[types.User](ctx, db, "email", "")
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups")
},
}
} // migratePostAuto migrates the SQLite database to the latest schema
func migratePostAuto(ctx context.Context, db *gorm.DB) error {
@@ -381,6 +386,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc {
}
})
},
func(db *gorm.DB) error {
return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any {
return &types.GroupUser{
AccountID: accountID,
GroupID: value,
UserID: id,
}
})
},
}
}