mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 17:08:53 +00:00
fix tests
This commit is contained in:
@@ -1404,7 +1404,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
|
||||
addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups)
|
||||
for _, group := range addNewGroups {
|
||||
err = transaction.AddUserToGroup(ctx, userAuth.AccountId, group, user.Id)
|
||||
err = transaction.AddUserToGroup(ctx, userAuth.AccountId, user.Id, group)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error adding user to group: %w", err)
|
||||
}
|
||||
@@ -1416,6 +1416,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
return fmt.Errorf("error removing user from group: %w", err)
|
||||
}
|
||||
}
|
||||
user.AutoGroups = updatedAutoGroups
|
||||
|
||||
if err = transaction.SaveUser(ctx, user); err != nil {
|
||||
return fmt.Errorf("error saving user: %w", err)
|
||||
|
||||
@@ -1723,6 +1723,13 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Id: "user1",
|
||||
Role: types.UserRoleAdmin,
|
||||
AutoGroups: []string{"group1"},
|
||||
Groups: []*types.GroupUser{
|
||||
{
|
||||
AccountID: "account1",
|
||||
UserID: "user1",
|
||||
GroupID: "group1",
|
||||
},
|
||||
},
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
|
||||
@@ -177,7 +177,8 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
generateAccountSQLTypes(account)
|
||||
|
||||
// Encrypt sensitive user data before saving
|
||||
for i := range account.UsersG {
|
||||
for i, user := range account.UsersG {
|
||||
user.StoreAutoGroups()
|
||||
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt user: %w", err)
|
||||
}
|
||||
@@ -185,7 +186,6 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
|
||||
for _, group := range account.GroupsG {
|
||||
group.StoreGroupPeers()
|
||||
group.StoreGroupUsers()
|
||||
}
|
||||
|
||||
err := s.transaction(func(tx *gorm.DB) error {
|
||||
@@ -213,6 +213,9 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
took := time.Since(start)
|
||||
if s.metrics != nil {
|
||||
@@ -220,7 +223,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
|
||||
}
|
||||
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateAccountSQLTypes generates the GORM compatible types for the account
|
||||
@@ -244,8 +247,8 @@ func generateAccountSQLTypes(account *types.Account) {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
user.LoadAutoGroups()
|
||||
account.UsersG = append(account.UsersG, *user)
|
||||
account.UsersG = append(account.UsersG, user)
|
||||
account.Users = nil
|
||||
}
|
||||
|
||||
for id, group := range account.Groups {
|
||||
@@ -750,7 +753,6 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
g.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -780,7 +782,6 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
|
||||
for _, g := range groups {
|
||||
g.LoadGroupPeers()
|
||||
g.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -881,9 +882,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
Preload("SetupKeysG").
|
||||
Preload("PeersG").
|
||||
Preload("UsersG").
|
||||
Preload("UsersG.GroupUser").
|
||||
Preload("UsersG.Groups").
|
||||
Preload("GroupsG").
|
||||
Preload("GroupsG.GroupPeers").
|
||||
Preload("GroupsG.GroupUsers").
|
||||
Preload("RoutesG").
|
||||
Preload("NameServerGroupsG").
|
||||
Preload("PostureChecks").
|
||||
@@ -931,7 +933,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
account.Users[user.Id] = &user
|
||||
account.Users[user.Id] = user
|
||||
user.PATsG = nil
|
||||
}
|
||||
account.UsersG = nil
|
||||
@@ -1221,7 +1223,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
|
||||
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for i := range account.UsersG {
|
||||
user := &account.UsersG[i]
|
||||
user := account.UsersG[i]
|
||||
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt user: %w", err)
|
||||
}
|
||||
@@ -1630,19 +1632,19 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
|
||||
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]*types.User, error) {
|
||||
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
|
||||
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.User, error) {
|
||||
var u types.User
|
||||
var lastLogin, createdAt sql.NullTime
|
||||
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
|
||||
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
|
||||
if err != nil {
|
||||
return u, err
|
||||
return &u, err
|
||||
}
|
||||
|
||||
if lastLogin.Valid {
|
||||
@@ -1664,7 +1666,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
|
||||
u.PendingApproval = pendingApproval.Bool
|
||||
}
|
||||
|
||||
return u, nil
|
||||
return &u, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2828,7 +2830,6 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
|
||||
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
@@ -3230,7 +3231,6 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
|
||||
return group, nil
|
||||
}
|
||||
@@ -3262,7 +3262,6 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
|
||||
return &group, nil
|
||||
}
|
||||
@@ -3284,7 +3283,6 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
group.LoadGroupPeers()
|
||||
group.LoadGroupUsers()
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*ty
|
||||
if user.AutoGroups == nil {
|
||||
user.AutoGroups = []string{}
|
||||
}
|
||||
account.Users[user.Id] = &user
|
||||
account.Users[user.Id] = user
|
||||
user.PATsG = nil
|
||||
}
|
||||
account.UsersG = nil
|
||||
@@ -900,7 +900,7 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty
|
||||
|
||||
account.Users = make(map[string]*types.User, len(account.UsersG))
|
||||
for i := range account.UsersG {
|
||||
user := &account.UsersG[i]
|
||||
user := account.UsersG[i]
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken)
|
||||
if userPats, ok := patsByUserID[user.Id]; ok {
|
||||
for j := range userPats {
|
||||
|
||||
@@ -87,7 +87,7 @@ type Account struct {
|
||||
Peers map[string]*nbpeer.Peer `gorm:"-"`
|
||||
PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Users map[string]*User `gorm:"-"`
|
||||
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
UsersG []*User `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Groups map[string]*Group `gorm:"-"`
|
||||
GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
Policies []*Policy `gorm:"foreignKey:AccountID;references:id"`
|
||||
|
||||
@@ -28,7 +28,6 @@ type Group struct {
|
||||
// Peers list of the group
|
||||
Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership
|
||||
GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
Users []string `gorm:"-"`
|
||||
GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"`
|
||||
|
||||
// Resources contains a list of resources in that group
|
||||
@@ -69,26 +68,6 @@ func (g *Group) StoreGroupPeers() {
|
||||
g.Peers = []string{}
|
||||
}
|
||||
|
||||
func (g *Group) LoadGroupUsers() {
|
||||
g.Users = make([]string, len(g.GroupUsers))
|
||||
for i, user := range g.GroupUsers {
|
||||
g.Users[i] = user.UserID
|
||||
}
|
||||
g.GroupUsers = []GroupUser{}
|
||||
}
|
||||
|
||||
func (g *Group) StoreGroupUsers() {
|
||||
g.GroupUsers = make([]GroupUser, len(g.Users))
|
||||
for i, user := range g.Users {
|
||||
g.GroupUsers[i] = GroupUser{
|
||||
AccountID: g.AccountID,
|
||||
GroupID: g.ID,
|
||||
UserID: user,
|
||||
}
|
||||
}
|
||||
g.Users = []string{}
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the group
|
||||
func (g *Group) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
@@ -99,21 +78,41 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an
|
||||
}
|
||||
|
||||
func (g *Group) Copy() *Group {
|
||||
var peers []string
|
||||
if g.Peers != nil {
|
||||
peers = make([]string, len(g.Peers))
|
||||
copy(peers, g.Peers)
|
||||
}
|
||||
|
||||
var groupPeers []GroupPeer
|
||||
if g.GroupPeers != nil {
|
||||
groupPeers = make([]GroupPeer, len(g.GroupPeers))
|
||||
copy(groupPeers, g.GroupPeers)
|
||||
}
|
||||
|
||||
var groupUsers []GroupUser
|
||||
if g.GroupUsers != nil {
|
||||
groupUsers = make([]GroupUser, len(g.GroupUsers))
|
||||
copy(groupUsers, g.GroupUsers)
|
||||
}
|
||||
|
||||
var resources []Resource
|
||||
if g.Resources != nil {
|
||||
resources = make([]Resource, len(g.Resources))
|
||||
copy(resources, g.Resources)
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
ID: g.ID,
|
||||
AccountID: g.AccountID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: make([]string, len(g.Peers)),
|
||||
GroupPeers: make([]GroupPeer, len(g.GroupPeers)),
|
||||
GroupUsers: make([]GroupUser, len(g.GroupUsers)),
|
||||
Resources: make([]Resource, len(g.Resources)),
|
||||
Peers: peers,
|
||||
GroupPeers: groupPeers,
|
||||
GroupUsers: groupUsers,
|
||||
Resources: resources,
|
||||
IntegrationReference: g.IntegrationReference,
|
||||
}
|
||||
copy(group.Peers, g.Peers)
|
||||
copy(group.GroupPeers, g.GroupPeers)
|
||||
copy(group.GroupUsers, g.GroupUsers)
|
||||
copy(group.Resources, g.Resources)
|
||||
return group
|
||||
}
|
||||
|
||||
|
||||
@@ -113,6 +113,7 @@ func (u *User) LoadAutoGroups() {
|
||||
for _, group := range u.Groups {
|
||||
u.AutoGroups = append(u.AutoGroups, group.GroupID)
|
||||
}
|
||||
u.Groups = []*GroupUser{}
|
||||
}
|
||||
|
||||
func (u *User) StoreAutoGroups() {
|
||||
@@ -124,6 +125,7 @@ func (u *User) StoreAutoGroups() {
|
||||
UserID: u.Id,
|
||||
})
|
||||
}
|
||||
u.AutoGroups = []string{}
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
@@ -218,10 +220,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
|
||||
// Copy the user
|
||||
func (u *User) Copy() *User {
|
||||
groupUsers := make([]*GroupUser, len(u.Groups))
|
||||
copy(groupUsers, u.Groups)
|
||||
autoGroups := make([]string, len(u.AutoGroups))
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
var groupUsers []*GroupUser
|
||||
if u.Groups != nil {
|
||||
groupUsers = make([]*GroupUser, len(u.Groups))
|
||||
copy(groupUsers, u.Groups)
|
||||
}
|
||||
var autoGroups []string
|
||||
if u.AutoGroups != nil {
|
||||
autoGroups = make([]string, len(u.AutoGroups))
|
||||
copy(autoGroups, u.AutoGroups)
|
||||
}
|
||||
|
||||
pats := make(map[string]*PersonalAccessToken, len(u.PATs))
|
||||
for k, v := range u.PATs {
|
||||
|
||||
@@ -745,6 +745,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
updatedUser.Role = update.Role
|
||||
updatedUser.Blocked = update.Blocked
|
||||
updatedUser.AutoGroups = update.AutoGroups
|
||||
updatedUser.StoreAutoGroups()
|
||||
// these two fields can't be set via API, only via direct call to the method
|
||||
updatedUser.Issued = update.Issued
|
||||
updatedUser.IntegrationReference = update.IntegrationReference
|
||||
@@ -772,11 +773,6 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
|
||||
updatedUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, updatedUser.Id)
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, err
|
||||
}
|
||||
|
||||
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction)
|
||||
|
||||
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
|
||||
|
||||
@@ -345,6 +345,9 @@ func TestUser_Copy(t *testing.T) {
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "servicename",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Groups: []*types.GroupUser{
|
||||
{AccountID: "accountId", GroupID: "groupId", UserID: "userId"},
|
||||
},
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"pat1": {
|
||||
ID: "pat1",
|
||||
@@ -1338,6 +1341,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
|
||||
updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update)
|
||||
if tc.expectedErr {
|
||||
require.Errorf(t, err, "expecting SaveUser to throw an error")
|
||||
@@ -1653,6 +1658,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
AutoGroups: []string{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
},
|
||||
@@ -1672,6 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
AutoGroups: []string{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.Admin),
|
||||
},
|
||||
@@ -1691,6 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
AutoGroups: []string{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Restricted: true,
|
||||
@@ -1712,6 +1720,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) {
|
||||
LastLogin: time.Time{},
|
||||
Issued: "api",
|
||||
IntegrationReference: integration_reference.IntegrationReference{},
|
||||
AutoGroups: []string{},
|
||||
},
|
||||
Permissions: mergeRolePermissions(roles.User),
|
||||
Restricted: false,
|
||||
|
||||
Reference in New Issue
Block a user