mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 09:46:40 +00:00
Revert "Merge branch 'main' into feature/remote-debug"
This reverts commit6d6333058c, reversing changes made to446aded1f7.
This commit is contained in:
@@ -52,6 +52,7 @@ const (
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
type SqlStore struct {
|
||||
db *gorm.DB
|
||||
resourceLocks sync.Map
|
||||
globalAccountLock sync.Mutex
|
||||
metrics telemetry.AppMetrics
|
||||
installationPK int
|
||||
@@ -218,6 +219,44 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireWriteLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// AcquireReadLockByUID acquires an ID lock for writing to a resource and returns a function that releases the lock
|
||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
}
|
||||
|
||||
// Deprecated: Full account operations are no longer supported
|
||||
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
|
||||
start := time.Now()
|
||||
@@ -989,7 +1028,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
|
||||
|
||||
func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
|
||||
var account types.Account
|
||||
result := s.db.Select("id").Order("created_at desc").Limit(1).Find(&account)
|
||||
result := s.db.WithContext(ctx).Select("id").Order("created_at desc").Limit(1).Find(&account)
|
||||
if result.Error != nil {
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@@ -1474,7 +1513,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
@@ -1489,7 +1528,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupI
|
||||
|
||||
// RemovePeerFromGroup removes a peer from a group
|
||||
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
|
||||
err := s.db.
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
|
||||
|
||||
if err != nil {
|
||||
@@ -1502,7 +1541,7 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group
|
||||
|
||||
// RemovePeerFromAllGroups removes a peer from all groups
|
||||
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
|
||||
err := s.db.
|
||||
err := s.db.WithContext(ctx).
|
||||
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
|
||||
|
||||
if err != nil {
|
||||
@@ -2090,7 +2129,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
|
||||
return fmt.Errorf("delete policy rules: %w", err)
|
||||
}
|
||||
@@ -2781,7 +2820,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) {
|
||||
tx := s.db
|
||||
tx := s.db.WithContext(ctx)
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -2922,22 +2961,3 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return []*nbpeer.Peer{}, nil
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
|
||||
Select("DISTINCT peer_id").
|
||||
Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
|
||||
|
||||
result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
@@ -3607,113 +3607,3 @@ func intToIPv4(n uint32) net.IP {
|
||||
binary.BigEndian.PutUint32(ip, n)
|
||||
return ip
|
||||
}
|
||||
|
||||
func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
group1ID := "test-group-1"
|
||||
group2ID := "test-group-2"
|
||||
emptyGroupID := "empty-group"
|
||||
|
||||
peer1 := "cfefqs706sqkneg59g4g"
|
||||
peer2 := "cfeg6sf06sqkneg59g50"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedPeers []string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve peers from single group with multiple peers",
|
||||
groupIDs: []string{group1ID},
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from single group with one peer",
|
||||
groupIDs: []string{group2ID},
|
||||
expectedPeers: []string{peer1},
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from multiple groups (with overlap)",
|
||||
groupIDs: []string{group1ID, group2ID},
|
||||
expectedPeers: []string{peer1, peer2}, // should deduplicate
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from existing 'All' group",
|
||||
groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from empty group",
|
||||
groupIDs: []string{emptyGroupID},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "retrieve peers from non-existing group",
|
||||
groupIDs: []string{"non-existing-group"},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty group IDs list",
|
||||
groupIDs: []string{},
|
||||
expectedPeers: []string{},
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "mix of existing and non-existing groups",
|
||||
groupIDs: []string{group1ID, "non-existing-group"},
|
||||
expectedPeers: []string{peer1, peer2},
|
||||
expectedCount: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
groups := []*types.Group{
|
||||
{
|
||||
ID: group1ID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
{
|
||||
ID: group2ID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
|
||||
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID))
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID))
|
||||
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID))
|
||||
|
||||
peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
|
||||
if tt.expectedCount > 0 {
|
||||
actualPeerIDs := make([]string, len(peers))
|
||||
for i, peer := range peers {
|
||||
actualPeerIDs[i] = peer.ID
|
||||
}
|
||||
assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs)
|
||||
|
||||
// Verify all returned peers belong to the correct account
|
||||
for _, peer := range peers {
|
||||
assert.Equal(t, accountID, peer.AccountID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,7 +136,6 @@ type Store interface {
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
@@ -169,6 +168,10 @@ type Store interface {
|
||||
GetInstallationID() string
|
||||
SaveInstallationID(ctx context.Context, ID string) error
|
||||
|
||||
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
|
||||
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
|
||||
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
|
||||
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
|
||||
AcquireGlobalLock(ctx context.Context) func()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user