[management] Refactor User JWT group sync (#2690)

* Refactor GetAccountIDByUserOrAccountID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix no jwt groups synced

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests and lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Move the account peer update outside the transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move event store outside of transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get user with update lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run jwt sync in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
Bethuel Mmbaga
2024-10-04 17:17:01 +03:00
committed by GitHub
parent 158936fb15
commit 7f09b39769
8 changed files with 520 additions and 177 deletions

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"runtime"
"runtime/debug"
"slices"
"strings"
"sync"
"time"
@@ -378,15 +379,26 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
Create(&usersToSave).Error
}
// SaveGroups saves the given list of groups to the database.
// It updates existing groups if a conflict occurs.
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
groupsToSave := make([]nbgroup.Group, 0, len(groups))
for _, group := range groups {
group.AccountID = accountID
groupsToSave = append(groupsToSave, *group)
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
return nil
}
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -1021,6 +1033,89 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil
}
// AddUserPeersToGroups adds the user's peers to specified groups in database.
func (s *SqlStore) AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
var userPeerIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
if result.Error != nil {
return status.Errorf(status.Internal, "issue finding user peers")
}
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
for _, gid := range groupIDs {
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return err
}
groupPeers := make(map[string]struct{})
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for _, pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = group.Peers[:0]
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
groupsToUpdate = append(groupsToUpdate, group)
}
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
// RemoveUserPeersFromGroups removes the user's peers from specified groups in database.
func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error {
if len(groupIDs) == 0 {
return nil
}
var userPeerIDs []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(LockingStrengthShare)}).Select("id").
Where("account_id = ? AND user_id = ?", accountID, userID).Model(&nbpeer.Peer{}).Find(&userPeerIDs)
if result.Error != nil {
return status.Errorf(status.Internal, "issue finding user peers")
}
groupsToUpdate := make([]*nbgroup.Group, 0, len(groupIDs))
for _, gid := range groupIDs {
group, err := s.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return err
}
if group.Name == "All" {
continue
}
update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if !slices.Contains(userPeerIDs, pid) {
update = append(update, pid)
}
}
group.Peers = update
groupsToUpdate = append(groupsToUpdate, group)
}
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")
@@ -1127,6 +1222,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)