fix tests

This commit is contained in:
pascal
2026-01-08 20:32:24 +01:00
parent f7ee019f26
commit ea37d4b768
9 changed files with 78 additions and 60 deletions

View File

@@ -177,7 +177,8 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
generateAccountSQLTypes(account)
// Encrypt sensitive user data before saving
for i := range account.UsersG {
for i, user := range account.UsersG {
user.StoreAutoGroups()
if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt user: %w", err)
}
@@ -185,7 +186,6 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
for _, group := range account.GroupsG {
group.StoreGroupPeers()
group.StoreGroupUsers()
}
err := s.transaction(func(tx *gorm.DB) error {
@@ -213,6 +213,9 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
}
return nil
})
if err != nil {
return err
}
took := time.Since(start)
if s.metrics != nil {
@@ -220,7 +223,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro
}
log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds())
return err
return nil
}
// generateAccountSQLTypes generates the GORM compatible types for the account
@@ -244,8 +247,8 @@ func generateAccountSQLTypes(account *types.Account) {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
user.LoadAutoGroups()
account.UsersG = append(account.UsersG, *user)
account.UsersG = append(account.UsersG, user)
account.Users = nil
}
for id, group := range account.Groups {
@@ -750,7 +753,6 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -780,7 +782,6 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
for _, g := range groups {
g.LoadGroupPeers()
g.LoadGroupUsers()
}
return groups, nil
@@ -881,9 +882,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("UsersG.GroupUser").
Preload("UsersG.Groups").
Preload("GroupsG").
Preload("GroupsG.GroupPeers").
Preload("GroupsG.GroupUsers").
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
@@ -931,7 +933,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
account.Users[user.Id] = &user
account.Users[user.Id] = user
user.PATsG = nil
}
account.UsersG = nil
@@ -1221,7 +1223,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user := account.UsersG[i]
if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt user: %w", err)
}
@@ -1630,19 +1632,19 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
return peers, nil
}
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) {
func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]*types.User, error) {
const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
}
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) {
users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.User, error) {
var u types.User
var lastLogin, createdAt sql.NullTime
var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool
err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name)
if err != nil {
return u, err
return &u, err
}
if lastLogin.Valid {
@@ -1664,7 +1666,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User
u.PendingApproval = pendingApproval.Bool
}
return u, nil
return &u, nil
})
if err != nil {
return nil, err
@@ -2828,7 +2830,6 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
}
return groups, nil
@@ -3230,7 +3231,6 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return group, nil
}
@@ -3262,7 +3262,6 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
group.LoadGroupPeers()
group.LoadGroupUsers()
return &group, nil
}
@@ -3284,7 +3283,6 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
group.LoadGroupUsers()
groupsMap[group.ID] = group
}