Refactor ephemeral peers and mark PAT as used

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-31 21:50:05 +03:00
parent b7525d9fe8
commit 6b94f6e4e7
7 changed files with 99 additions and 60 deletions

View File

@@ -39,6 +39,7 @@ const (
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
batchSize = 500
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
@@ -592,7 +593,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
return nil, status.NewAccountNotFoundError()
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -708,7 +709,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
return "", status.NewAccountNotFoundError()
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -719,6 +720,21 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError()
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -798,7 +814,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountNetwork, idQueryCondition, accountID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
return nil, status.NewAccountNotFoundError()
}
log.WithContext(ctx).Errorf("error when getting network from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
@@ -1132,9 +1148,27 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
return peer, nil
}
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("ephemeral = ?", true).
FindInBatches(&batchPeers, batchSize, func(tx *gorm.DB, batch int) error {
allEphemeralPeers = append(allEphemeralPeers, batchPeers...)
return nil
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error)
return nil, fmt.Errorf("failed to retrieve ephemeral peers")
}
return allEphemeralPeers, nil
}
func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, peerID)
Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete peer from store")
@@ -1629,6 +1663,27 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
return pats, nil
}
// MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := PersonalAccessToken{
LastUsed: time.Now().UTC(),
}
fieldsToUpdate := []string{"last_used"}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(fieldsToUpdate).Where(idQueryCondition, patID).Updates(&patCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark PAT as used: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark PAT as used")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError()
}
return nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)