mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 03:06:38 +00:00
refactor groups methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -378,17 +378,6 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
Create(&usersToSave).Error
|
||||
}
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
// It updates existing groups if a conflict occurs.
|
||||
func (s *SqlStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
|
||||
groupsToSave := make([]nbgroup.Group, 0, len(groups))
|
||||
for _, group := range groups {
|
||||
group.AccountID = accountID
|
||||
groupsToSave = append(groupsToSave, *group)
|
||||
}
|
||||
return s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&groupsToSave).Error
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
return nil
|
||||
@@ -500,6 +489,11 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetAccountUsers returns all users associated with the account.
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||
return getRecords[User](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, accountIDCondition, accountID)
|
||||
@@ -1135,9 +1129,38 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// SaveGroup saves a group to the database.
|
||||
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
|
||||
return saveRecord[nbgroup.Group](s.db.WithContext(ctx), lockStrength, group)
|
||||
}
|
||||
|
||||
// SaveGroups saves the given list of groups to the database.
|
||||
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) error {
|
||||
return deleteRecordByID[nbgroup.Group](s.db.WithContext(ctx), lockStrength, groupID, accountID)
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Where("account_id AND id IN ?", accountID, groupIDs).Delete(&nbgroup.Group{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
return getRecords[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
@@ -1159,7 +1182,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
@@ -1188,7 +1211,7 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
@@ -1209,7 +1232,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
@@ -1231,7 +1254,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
@@ -1277,13 +1300,13 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
|
||||
|
||||
// GetAccountPeers retrieves peers for an account.
|
||||
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
return getRecords[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true)
|
||||
return getRecords[*nbpeer.Peer](db, lockStrength, accountID)
|
||||
return getRecords[nbpeer.Peer](db, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||
@@ -1292,14 +1315,12 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]*T, error) {
|
||||
var record []*T
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
recordType := getRecordType(record)
|
||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
@@ -1313,8 +1334,7 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
recordType := getRecordType(record)
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||
@@ -1324,15 +1344,23 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// saveRecord saves a record to the database.
|
||||
func saveRecord[T any](db *gorm.DB, lockStrength LockingStrength, record *T) error {
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(record)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save %s to store: %v", getRecordType(record), result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteRecordByID deletes a record by its ID and account ID from the database.
|
||||
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
|
||||
var record T
|
||||
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
recordType := getRecordType(record)
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
|
||||
}
|
||||
@@ -1343,3 +1371,8 @@ func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRecordType(record any) string {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user