mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 03:06:38 +00:00
wip refactor peer methods
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -782,6 +782,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
// SaveAccountSettings stores the account settings in DB.
|
||||
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
|
||||
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
@@ -1055,6 +1070,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
var accountID string
|
||||
@@ -1120,13 +1148,13 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||
return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&policy).Error
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy).Error
|
||||
}
|
||||
|
||||
// DeletePolicy deletes a policy from the database.
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID string) error {
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&Policy{}, idQueryCondition, policyID).Error
|
||||
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID).Error
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
@@ -1141,23 +1169,21 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
|
||||
|
||||
// SavePostureChecks saves a posture checks to the database.
|
||||
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&postureCheck)
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
|
||||
return status.Errorf(status.InvalidArgument, "name should be unique")
|
||||
}
|
||||
return result.Error
|
||||
return status.Errorf(status.Internal, "failed to save posture checks to store: %v", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePostureChecks deletes a posture checks from the database.
|
||||
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error {
|
||||
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID, accountID string) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&posture.Checks{}, idQueryCondition, postureChecksID).Error
|
||||
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID).Error
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
@@ -1170,6 +1196,17 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
|
||||
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||
}
|
||||
|
||||
// SaveRoute saves a route to the database.
|
||||
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route).Error
|
||||
}
|
||||
|
||||
// DeleteRoute deletes a route from the database.
|
||||
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, routeID, accountID string) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID).Error
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
@@ -1187,9 +1224,9 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
// DeleteSetupKey deletes a setup key from the database.
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID string) error {
|
||||
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID, accountID string) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&posture.Checks{}, idQueryCondition, setupKeyID).Error
|
||||
Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, setupKeyID).Error
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
@@ -1202,6 +1239,58 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
// SaveNameServerGroup saves a name server group to the database.
|
||||
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup).Error
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes a name server group from the database.
|
||||
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroupID, accountID string) error {
|
||||
return deleteRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nameServerGroupID, accountID)
|
||||
}
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
|
||||
var pat PersonalAccessToken
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&pat, "user_id = ? and id = ?", userID, patID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// SavePAT saves a personal access token to the database.
|
||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat).Error
|
||||
}
|
||||
|
||||
// DeletePAT deletes a personal access token from the database.
|
||||
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, patID, userID string) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&PersonalAccessToken{}, "user_id = ? and id = ?", userID, patID).Error
|
||||
}
|
||||
|
||||
// GetAccountPeers retrieves peers for an account.
|
||||
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true)
|
||||
return getRecords[*nbpeer.Peer](db, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, peerID string, accountID string) (*nbpeer.Peer, error) {
|
||||
return getRecordByID[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, peerID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
@@ -1234,3 +1323,23 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// deleteRecordByID deletes a record by its ID and account ID from the database.
|
||||
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
|
||||
var record T
|
||||
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "%s not found", recordType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user