mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
migrate auto groups to different table
This commit is contained in:
@@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
return nil, fmt.Errorf("migratePreAuto: %w", err)
|
||||
}
|
||||
err = db.AutoMigrate(
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
|
||||
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{},
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
@@ -185,6 +185,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
group.StoreGroupUsers()
|
||||
}
|
||||
|
||||
err := s.transaction(func(tx *gorm.DB) error {
|
||||
@@ -243,6 +244,7 @@ func generateAccountSQLTypes(account *types.Account) {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
account.UsersG = append(account.UsersG, *user)
|
||||
}
|
||||
|
||||
@@ -453,6 +455,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
|
||||
userCopy := user.Copy()
|
||||
userCopy.Email = user.Email
|
||||
userCopy.Name = user.Name
|
||||
userCopy.StoreAutoGroups()
|
||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
}
|
||||
@@ -472,6 +475,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
|
||||
userCopy := user.Copy()
|
||||
userCopy.Email = user.Email
|
||||
userCopy.Name = user.Name
|
||||
userCopy.StoreAutoGroups()
|
||||
|
||||
if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
@@ -617,6 +621,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
var user types.User
|
||||
result := tx.
|
||||
Preload("Groups").
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
Where("personal_access_tokens.id = ?", patID).Take(&user)
|
||||
if result.Error != nil {
|
||||
@@ -631,6 +636,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
user.LoadAutoGroups()
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -641,7 +648,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result := tx.Take(&user, idQueryCondition, userID)
|
||||
result := tx.Preload("Groups").Take(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@@ -653,6 +660,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
user.LoadAutoGroups()
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -680,7 +689,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
var users []*types.User
|
||||
result := tx.Find(&users, accountIDCondition, accountID)
|
||||
result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@@ -693,6 +702,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
}
|
||||
|
||||
return users, nil
|
||||
@@ -705,7 +715,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||
result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
|
||||
@@ -717,6 +727,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
|
||||
user.LoadAutoGroups()
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
@@ -738,6 +750,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
g.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -767,6 +780,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
g.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -867,6 +881,8 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("SetupKeysG").
|
||||
Preload("PeersG").
|
||||
Preload("UsersG").
|
||||
Preload("UsersG.GroupUser").
|
||||
Preload("GroupsG").
|
||||
Preload("GroupsG.GroupPeers").
|
||||
Preload("RoutesG").
|
||||
Preload("NameServerGroupsG").
|
||||
@@ -908,9 +924,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
pat.UserID = ""
|
||||
user.PATs[pat.ID] = &pat
|
||||
}
|
||||
if user.AutoGroups == nil {
|
||||
user.AutoGroups = []string{}
|
||||
if user.Groups == nil {
|
||||
user.Groups = []*types.GroupUser{}
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
@@ -1116,8 +1133,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
wg.Add(3)
|
||||
errChan = make(chan error, 3)
|
||||
wg.Add(4)
|
||||
errChan = make(chan error, 4)
|
||||
|
||||
var pats []types.PersonalAccessToken
|
||||
go func() {
|
||||
@@ -1149,6 +1166,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
}
|
||||
}()
|
||||
|
||||
var groupUsers []types.GroupUser
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var err error
|
||||
groupUsers, err = s.getGroupUsers(ctx, userIDs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
for e := range errChan {
|
||||
@@ -1174,6 +1201,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
|
||||
}
|
||||
|
||||
groupsByUserID := make(map[string][]*types.GroupUser)
|
||||
for i := range groupUsers {
|
||||
gu := &groupUsers[i]
|
||||
groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu)
|
||||
}
|
||||
|
||||
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
|
||||
for i := range account.SetupKeysG {
|
||||
key := &account.SetupKeysG[i]
|
||||
@@ -1199,6 +1232,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
user.PATs[pat.ID] = pat
|
||||
}
|
||||
}
|
||||
user.Groups = groupsByUserID[user.Id]
|
||||
user.LoadAutoGroups()
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
@@ -1596,43 +1631,40 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
}
|
||||
|
||||
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
||||
var u types.User
|
||||
var autoGroups []byte
|
||||
var lastLogin, createdAt sql.NullTime
|
||||
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||
if err == nil {
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = &lastLogin.Time
|
||||
}
|
||||
if createdAt.Valid {
|
||||
u.CreatedAt = createdAt.Time
|
||||
}
|
||||
if isServiceUser.Valid {
|
||||
u.IsServiceUser = isServiceUser.Bool
|
||||
}
|
||||
if nonDeletable.Valid {
|
||||
u.NonDeletable = nonDeletable.Bool
|
||||
}
|
||||
if blocked.Valid {
|
||||
u.Blocked = blocked.Bool
|
||||
}
|
||||
if pendingApproval.Valid {
|
||||
u.PendingApproval = pendingApproval.Bool
|
||||
}
|
||||
if autoGroups != nil {
|
||||
_ = json.Unmarshal(autoGroups, &u.AutoGroups)
|
||||
} else {
|
||||
u.AutoGroups = []string{}
|
||||
}
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||
if err != nil {
|
||||
return u, err
|
||||
}
|
||||
return u, err
|
||||
|
||||
if lastLogin.Valid {
|
||||
u.LastLogin = &lastLogin.Time
|
||||
}
|
||||
if createdAt.Valid {
|
||||
u.CreatedAt = createdAt.Time
|
||||
}
|
||||
if isServiceUser.Valid {
|
||||
u.IsServiceUser = isServiceUser.Bool
|
||||
}
|
||||
if nonDeletable.Valid {
|
||||
u.NonDeletable = nonDeletable.Bool
|
||||
}
|
||||
if blocked.Valid {
|
||||
u.Blocked = blocked.Bool
|
||||
}
|
||||
if pendingApproval.Valid {
|
||||
u.PendingApproval = pendingApproval.Bool
|
||||
}
|
||||
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2038,6 +2070,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type
|
||||
return groupPeers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)`
|
||||
rows, err := s.pool.Query(ctx, query, userIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return groupUsers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
|
||||
var user types.User
|
||||
result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
|
||||
@@ -2659,6 +2707,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error {
|
||||
user := &types.GroupUser{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(user).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err)
|
||||
return status.Errorf(status.Internal, "failed to add user to group")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error {
|
||||
result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error)
|
||||
return status.Errorf(status.Internal, "failed to remove user from group")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.
|
||||
@@ -2745,6 +2828,7 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -3146,6 +3230,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
|
||||
return group, nil
|
||||
}
|
||||
@@ -3177,6 +3262,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
@@ -3198,6 +3284,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user