[management] Refactor peers to use store methods (#2893)

This commit is contained in:
Bethuel Mmbaga
2025-01-20 20:41:46 +03:00
committed by GitHub
parent c619bf5b0c
commit 1ad2cb5582
30 changed files with 1614 additions and 857 deletions

View File

@@ -313,12 +313,12 @@ func (s *SqlStore) GetInstallationID() string {
return installation.InstallationIDValue
}
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
err := s.db.Transaction(func(tx *gorm.DB) error {
err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving
var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
@@ -332,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error)
}
return nil
@@ -358,7 +358,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error)
}
if result.RowsAffected == 0 {
@@ -368,7 +368,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
return nil
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -376,12 +376,12 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error)
}
if result.RowsAffected == 0 {
@@ -391,22 +391,22 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
return nil
}
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
}
if result.RowsAffected == 0 && s.storeEngine != MysqlStoreEngine {
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
}
@@ -773,9 +773,10 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
return accountID, nil
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
var accountID string
result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -786,6 +787,20 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
return accountID, nil
}
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
var accountID string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "peer %s account not found", peerID)
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var accountID string
result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID)
@@ -865,7 +880,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
return nil, status.NewPeerNotFoundError(peerKey)
}
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
@@ -1096,9 +1111,10 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group types.Group
result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
@@ -1114,7 +1130,7 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
}
@@ -1122,9 +1138,10 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
}
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group types.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
@@ -1141,7 +1158,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
}
@@ -1201,13 +1218,52 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string
return nil
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID)
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
var groups []*types.Group
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
if query.Error != nil {
return nil, query.Error
}
return groups, nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.Create(peer).Error; err != nil {
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
if userID == "" {
return peers, nil
}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1221,7 +1277,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
First(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
return nil, status.NewPeerNotFoundError(peerID)
}
log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peer from store")
@@ -1247,6 +1303,68 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
return peersMap, nil
}
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store")
}
return peers, nil
}
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store")
}
return peers, nil
}
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("ephemeral = ?", true).
FindInBatches(&batchPeers, 1000, func(tx *gorm.DB, batch int) error {
allEphemeralPeers = append(allEphemeralPeers, batchPeers...)
return nil
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error)
return nil, fmt.Errorf("failed to retrieve ephemeral peers")
}
return allEphemeralPeers, nil
}
// DeletePeer removes a peer from the store.
func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete peer from store")
}
if result.RowsAffected == 0 {
return status.NewPeerNotFoundError(peerID)
}
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
@@ -1638,7 +1756,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
@@ -1650,7 +1768,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
var nsGroup *nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1665,7 +1783,7 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")