diff --git a/management/server/account_test.go b/management/server/account_test.go index 3f2bf0f91..ce405b275 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1240,9 +1240,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ - ID: "groupA", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID}, + AccountID: account.Id, + ID: "groupA", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID}, } if err := manager.SaveGroup(context.Background(), account.Id, userID, &group, true); err != nil { t.Errorf("save group: %v", err) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index bdd04f7ca..5d7aded3b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1459,7 +1459,6 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { } func Test_RegisterPeerRollbackOnFailure(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", "postgres") engine := os.Getenv("NETBIRD_STORE_ENGINE") if engine == "sqlite" || engine == "" { t.Skip("Skipping test because sqlite test store is not respecting foreign keys") @@ -1766,47 +1765,47 @@ func TestPeerAccountPeersUpdate(t *testing.T) { }) // Adding peer to unlinked group should not update account peers and not send peer update - // t.Run("adding peer to unlinked group", func(t *testing.T) { - // done := make(chan struct{}) - // go func() { - // peerShouldNotReceiveUpdate(t, updMsg) - // close(done) - // }() - // - // key, err := wgtypes.GeneratePrivateKey() - // require.NoError(t, err) - // - // expectedPeerKey := key.PublicKey().String() - // peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ - // Key: expectedPeerKey, - // Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - // }) - // require.NoError(t, err) - // - // select { - // case <-done: - // case <-time.After(time.Second): - // t.Error("timeout waiting for peerShouldNotReceiveUpdate") - // } - // }) + t.Run("adding peer to unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) // + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) // Deleting peer with unlinked group should not update account peers and not send peer update - // t.Run("deleting peer with unlinked group", func(t *testing.T) { - // done := make(chan struct{}) - // go func() { - // peerShouldNotReceiveUpdate(t, updMsg) - // close(done) - // }() - // - // err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) - // require.NoError(t, err) - // - // select { - // case <-done: - // case <-time.After(time.Second): - // t.Error("timeout waiting for peerShouldNotReceiveUpdate") - // } - // }) + t.Run("deleting peer with unlinked group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) // Updating peer label should update account peers and send peer update t.Run("updating peer label", func(t *testing.T) { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 390a6c220..c05c7e75c 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -208,7 +208,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro result = tx. Session(&gorm.Session{FullSaveAssociations: true}). - // Clauses(clause.OnConflict{UpdateAll: true}). + Clauses(clause.OnConflict{UpdateAll: true}). Create(account) if result.Error != nil { return result.Error @@ -459,6 +459,10 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, return nil } + for _, g := range groups { + g.StoreGroupPeers() + } + return s.db.Transaction(func(tx *gorm.DB) error { result := tx. Clauses( @@ -1773,6 +1777,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // we may need to reconsider changing the types. query := tx.Preload(clause.Associations) + switch s.storeEngine { + case types.PostgresStoreEngine: + query = query.Order("json_array_length(peers::json) DESC") + case types.MysqlStoreEngine: + query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") + default: + query = query.Order("json_array_length(peers) DESC") + } + result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 887f23ba3..bdeba4653 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2615,7 +2615,6 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { } func TestSqlStore_GetPeerGroups(t *testing.T) { - t.Setenv("NETBIRD_STORE_ENGINE", "postgres") store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err)