migrate auto groups to different table

This commit is contained in:
pascal
2026-01-08 15:45:48 +01:00
parent fb71b0d04b
commit f7ee019f26
7 changed files with 387 additions and 56 deletions

View File

@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
@@ -185,6 +185,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
for _, group := range account.GroupsG {
group.StoreGroupPeers()
group.StoreGroupUsers()
}
err := s.transaction(func(tx *gorm.DB) error {
@@ -243,6 +244,7 @@ func generateAccountSQLTypes(account *types.Account) {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
user.LoadAutoGroups()
account.UsersG = append(account.UsersG, *user)
}
@@ -453,6 +455,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
@@ -472,6 +475,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
userCopy := user.Copy()
userCopy.Email = user.Email
userCopy.Name = user.Name
userCopy.StoreAutoGroups()
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
@@ -617,6 +621,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
var user types.User
result := tx.
Preload("Groups").
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).Take(&user)
if result.Error != nil {
@@ -631,6 +636,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -641,7 +648,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, idQueryCondition, userID)
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -653,6 +660,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -680,7 +689,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
}
var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID)
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -693,6 +702,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
}
return users, nil
@@ -705,7 +715,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
}
var user types.User
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
@@ -717,6 +727,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
return nil, fmt.Errorf("decrypt user: %w", err)
}
user.LoadAutoGroups()
return &user, nil
}
@@ -738,6 +750,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -767,6 +780,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -867,6 +881,8 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("UsersG.GroupUser").
Preload("GroupsG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
@@ -908,9 +924,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.AutoGroups = []string{}
if user.Groups == nil {
user.Groups = []*types.GroupUser{}
}
user.LoadAutoGroups()
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
@@ -1116,8 +1133,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
wg.Add(4)
errChan = make(chan error, 4)
var pats []types.PersonalAccessToken
go func() {
@@ -1149,6 +1166,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
}()
var groupUsers []types.GroupUser
go func() {
defer wg.Done()
var err error
groupUsers, err = s.getGroupUsers(ctx, userIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
@@ -1174,6 +1201,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
groupsByUserID := make(map[string][]*types.GroupUser)
for i := range groupUsers {
gu := &groupUsers[i]
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
@@ -1199,6 +1232,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
user.PATs[pat.ID] = pat
}
}
user.Groups = groupsByUserID[user.Id]
user.LoadAutoGroups()
account.Users[user.Id] = user
}
@@ -1596,43 +1631,40 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
}
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
var u types.User
var autoGroups []byte
var lastLogin, createdAt sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err == nil {
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
if autoGroups != nil {
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
} else {
u.AutoGroups = []string{}
}
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err != nil {
return u, err
}
return u, err
if lastLogin.Valid {
u.LastLogin = &lastLogin.Time
}
if createdAt.Valid {
u.CreatedAt = createdAt.Time
}
if isServiceUser.Valid {
u.IsServiceUser = isServiceUser.Bool
}
if nonDeletable.Valid {
u.NonDeletable = nonDeletable.Bool
}
if blocked.Valid {
u.Blocked = blocked.Bool
}
if pendingApproval.Valid {
u.PendingApproval = pendingApproval.Bool
}
return u, nil
})
if err != nil {
return nil, err
@@ -2038,6 +2070,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
return groupPeers, nil
}
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
if len(userIDs) == 0 {
return nil, nil
}
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
rows, err := s.pool.Query(ctx, query, userIDs)
if err != nil {
return nil, err
}
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
if err != nil {
return nil, err
}
return groupUsers, nil
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
@@ -2659,6 +2707,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
return nil
}
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
user := &types.GroupUser{
AccountID: accountID,
GroupID: groupID,
UserID: userID,
}
err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
DoNothing: true,
}).Create(user).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
return status.Errorf(status.Internal, "failed to add user to group")
}
return nil
}
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
return status.Errorf(status.Internal, "failed to remove user from group")
}
if result.RowsAffected == 0 {
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.
@@ -2745,6 +2828,7 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
}
return groups, nil
@@ -3146,6 +3230,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return group, nil
}
@@ -3177,6 +3262,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return &group, nil
}
@@ -3198,6 +3284,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
groupsMap[group.ID] = group
}