add groups migration

This commit is contained in:
Pascal Fischer
2025-07-02 19:09:42 +02:00
parent 57961afe95
commit e23282b92c
12 changed files with 530 additions and 140 deletions

View File

@@ -96,7 +96,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
return nil, fmt.Errorf("migratePreAuto: %w", err)
}
err = db.AutoMigrate(
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{},
&types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
@@ -455,19 +455,34 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
return nil
}
result := s.db.
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Create(&groups)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
return s.db.Transaction(func(tx *gorm.DB) error {
for _, g := range groups {
g.StoreGroupPeers()
if err := tx.Model(&g).
Association("GroupPeers").
Replace(g.GroupPeers); err != nil {
log.WithContext(ctx).Errorf("failed to save group peers to store: %s", err)
return status.Errorf(status.Internal, "failed to save group peers to store")
}
}
result := tx.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true,
},
).
Save(&groups)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save groups to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save groups to store")
}
return nil
})
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -646,7 +661,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
}
var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID)
result := tx.Preload(clause.Associations).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -655,6 +670,10 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -669,6 +688,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
likePattern := `%"ID":"` + resourceID + `"%`
result := tx.
Preload(clause.Associations).
Where("resources LIKE ?", likePattern).
Find(&groups)
@@ -679,6 +699,10 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
return nil, result.Error
}
for _, g := range groups {
g.LoadGroupPeers()
}
return groups, nil
}
@@ -739,7 +763,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
var account types.Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload("GroupsG.GroupPeers"). // have to be specifies as this is nester reference
Preload(clause.Associations).
First(&account, idQueryCondition, accountID)
if result.Error != nil {
@@ -784,6 +809,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
group.LoadGroupPeers()
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
@@ -1285,55 +1311,71 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&group, "account_id = ? AND name = ?", accountID, "All")
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var groupID string
_ = s.db.Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
Scan(&groupID)
if groupID == "" {
return status.Errorf(status.NotFound, "group 'All' not found for account %s", accountID)
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerID {
return nil
}
}
err := s.db.Create(&types.GroupPeer{
GroupID: groupID,
PeerID: peerID,
}).Error
group.Peers = append(group.Peers, peerID)
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
if err != nil {
return status.Errorf(status.Internal, "error adding peer to group 'All': %v", err)
}
return nil
}
// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error {
var group types.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID).
First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, peerID string, groupID string) error {
peer := &types.GroupPeer{
GroupID: groupID,
PeerID: peerID,
}
for _, existingPeerID := range group.Peers {
if existingPeerID == peerId {
return nil
}
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer %s to group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to add peer to group")
}
group.Peers = append(group.Peers, peerId)
return nil
}
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil {
return status.Errorf(status.Internal, "issue updating group: %s", err)
// RemovePeerFromGroup removes a peer from a group
func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, groupID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "group_id = ? AND peer_id = ?", groupID, peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from group %s: %v", peerID, groupID, err)
return status.Errorf(status.Internal, "failed to remove peer from group")
}
return nil
}
// RemovePeerFromAllGroups removes a peer from all groups
func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error {
err := s.db.WithContext(ctx).
Delete(&types.GroupPeer{}, "peer_id = ?", peerID).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to remove peer %s from all groups: %v", peerID, err)
return status.Errorf(status.Internal, "failed to remove peer from all groups")
}
return nil
@@ -1401,12 +1443,19 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
var groups []*types.Group
query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
Joins("JOIN group_peers ON group_peers.group_id = groups.id").
Where("group_peers.peer_id = ?", peerId).
Preload(clause.Associations).
Find(&groups)
if query.Error != nil {
return nil, query.Error
}
for _, group := range groups {
group.LoadGroupPeers()
}
return groups, nil
}
@@ -1455,7 +1504,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error {
if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil {
if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1692,7 +1741,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
}
var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID)
@@ -1701,6 +1750,8 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
group.LoadGroupPeers()
return group, nil
}
@@ -1717,15 +1768,6 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// we may need to reconsider changing the types.
query := tx.Preload(clause.Associations)
switch s.storeEngine {
case types.PostgresStoreEngine:
query = query.Order("json_array_length(peers::json) DESC")
case types.MysqlStoreEngine:
query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC")
default:
query = query.Order("json_array_length(peers) DESC")
}
result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1745,7 +1787,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
}
var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
result := tx.Preload(clause.Associations).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
@@ -1753,6 +1795,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
group.LoadGroupPeers()
groupsMap[group.ID] = group
}
@@ -1761,17 +1804,34 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil")
}
group.StoreGroupPeers()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
if err := tx.Model(group).Association("Peers").Replace(group.Peers); err != nil {
log.WithContext(ctx).Errorf("failed to replace peers for group %s: %v", group.ID, err)
return status.Errorf(status.Internal, "failed to sync group peers")
}
if err := tx.Save(group).Error; err != nil {
log.WithContext(ctx).Errorf("failed to save group to store: %v", err)
return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
@@ -1788,6 +1848,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)