wip refactor peer methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-01 01:07:48 +03:00
parent f43a006c34
commit f9ed25f8b1
4 changed files with 261 additions and 23 deletions

View File

@@ -782,6 +782,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
return accountSettings.Settings, nil
}
// SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
@@ -1055,6 +1070,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
return &accountDNSSettings.DNSSettings, nil
}
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
@@ -1120,13 +1148,13 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
return s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&policy).Error
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy).Error
}
// DeletePolicy deletes a policy from the database.
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID string) error {
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, idQueryCondition, policyID).Error
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID).Error
}
// GetAccountPostureChecks retrieves posture checks for an account.
@@ -1141,23 +1169,21 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&postureCheck)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrDuplicatedKey) {
return status.Errorf(status.InvalidArgument, "name should be unique")
}
return result.Error
return status.Errorf(status.Internal, "failed to save posture checks to store: %v", result.Error)
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID string) error {
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, postureChecksID, accountID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, idQueryCondition, postureChecksID).Error
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID).Error
}
// GetAccountRoutes retrieves network routes for an account.
@@ -1170,6 +1196,17 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route).Error
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, routeID, accountID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID).Error
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
@@ -1187,9 +1224,9 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
}
// DeleteSetupKey deletes a setup key from the database.
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID string) error {
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, setupKeyID, accountID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, idQueryCondition, setupKeyID).Error
Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, setupKeyID).Error
}
// GetAccountNameServerGroups retrieves name server groups for an account.
@@ -1202,6 +1239,58 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
}
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup).Error
}
// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroupID, accountID string) error {
return deleteRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nameServerGroupID, accountID)
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "user_id = ? and id = ?", userID, patID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "PAT not found")
}
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
}
return &pat, nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat).Error
}
// DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, patID, userID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&PersonalAccessToken{}, "user_id = ? and id = ?", userID, patID).Error
}
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true)
return getRecords[*nbpeer.Peer](db, lockStrength, accountID)
}
// GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, peerID string, accountID string) (*nbpeer.Peer, error) {
return getRecordByID[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, peerID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
@@ -1234,3 +1323,23 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
}
return &record, nil
}
// deleteRecordByID deletes a record by its ID and account ID from the database.
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
var record T
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil
}