diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index fd168ef9b..9b2f25920 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -186,6 +186,10 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro generateAccountSQLTypes(account) + for _, group := range account.GroupsG { + group.StoreGroupPeers() + } + err := s.db.Transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { @@ -204,7 +208,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro result = tx. Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}). + // Clauses(clause.OnConflict{UpdateAll: true}). Create(account) if result.Error != nil { return result.Error @@ -247,7 +251,7 @@ func generateAccountSQLTypes(account *types.Account) { for id, group := range account.Groups { group.ID = id - account.GroupsG = append(account.GroupsG, *group) + account.GroupsG = append(account.GroupsG, group) } for id, route := range account.Routes { @@ -1436,20 +1440,30 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string // GetPeerGroups retrieves all groups assigned to a specific peer in a given account. func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { - tx := s.db + tx := s.db.Debug() if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var groups []*types.Group - query := tx. - Joins("JOIN group_peers ON group_peers.group_id = groups.id"). - Where("group_peers.peer_id = ?", peerId). - Preload(clause.Associations). - Find(&groups) + var groupIDs []string + err := s.db. + Table("group_peers"). + Where("peer_id = ?", peerId). + Pluck("group_id", &groupIDs).Error + if err != nil { + return nil, err + } + if len(groupIDs) == 0 { + return []*types.Group{}, nil // no matches + } - if query.Error != nil { - return nil, query.Error + var groups []*types.Group + err = s.db. + Where("id IN ?", groupIDs). + Preload("GroupPeers"). + Find(&groups).Error + if err != nil { + return nil, err } for _, group := range groups { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 9b5101c79..254062415 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2611,6 +2611,7 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { } func TestSqlStore_GetPeerGroups(t *testing.T) { + t.Setenv("NETBIRD_STORE_ENGINE", "postgres") store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) diff --git a/management/server/types/account.go b/management/server/types/account.go index 5a62ee4c6..4215b8b07 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -73,7 +73,7 @@ type Account struct { Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`