update transaction logic

This commit is contained in:
Pascal Fischer
2024-10-04 15:17:28 +02:00
parent adf521a9d9
commit e3f3d2c1bd
3 changed files with 102 additions and 52 deletions

View File

@@ -20,6 +20,11 @@ import (
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@@ -36,10 +41,6 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
) )
const ( const (
@@ -846,17 +847,7 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
// newly groups to create and an error if any occurred. // newly groups to create and an error if any occurred.
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) { func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return false, nil, nil, err
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return false, nil, nil, err
}
existedGroupsByName := make(map[string]*nbgroup.Group) existedGroupsByName := make(map[string]*nbgroup.Group)
for _, group := range groups { for _, group := range groups {
existedGroupsByName[group.Name] = group existedGroupsByName[group.Name] = group
@@ -880,7 +871,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID
if !exists { if !exists {
group = &nbgroup.Group{ group = &nbgroup.Group{
ID: xid.New().String(), ID: xid.New().String(),
AccountID: accountID, AccountID: user.AccountID,
Name: name, Name: name,
Issued: nbgroup.GroupIssuedJWT, Issued: nbgroup.GroupIssuedJWT,
} }
@@ -1836,16 +1827,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
if err != nil {
return err
}
// skip update if no changes
if !hasChanges {
return nil
}
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer func() { defer func() {
if unlockPeer != nil { if unlockPeer != nil {
@@ -1853,19 +1834,39 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
}() }()
if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { var addNewGroups []string
return fmt.Errorf("error saving groups: %w", err) var removeOldGroups []string
} var hasChanges bool
var user *User
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user: %w", err)
}
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, user, groups, jwtGroupsNames)
if err != nil {
return fmt.Errorf("error getting JWT groups changes: %w", err)
}
hasChanges = changed
// skip update if no changes
if !changed {
return nil
}
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err) return fmt.Errorf("error saving user: %w", err)
@@ -1873,7 +1874,22 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups) groups, err = transaction.GetAccountGroups(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account groups: %w", err)
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
if err != nil {
return fmt.Errorf("error getting user peers: %w", err)
}
updatedGroups, err := am.updateUserPeersInGroups(ctx, groupsMap, peers, addNewGroups, removeOldGroups)
if err != nil { if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err) return fmt.Errorf("error modifying user peers in groups: %w", err)
} }
@@ -1895,6 +1911,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return err return err
} }
if !hasChanges {
return nil
}
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil { if err != nil {

View File

@@ -1185,3 +1185,37 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, setupKey.UsedTimes) assert.Equal(t, 2, setupKey.UsedTimes)
} }
func TestSqlite_CreateAndGetObjcetInTransaction(t *testing.T) {
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
t.Cleanup(cleanup)
if err != nil {
t.Fatal(err)
}
group := &nbgroup.Group{
ID: "group-id",
AccountID: "account-id",
Name: "group-name",
Issued: "api",
Peers: nil,
}
store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
if err != nil {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
if err != nil {
t.Fatal("failed to get group")
return err
}
t.Logf("group: %v", group)
return nil
})
}

View File

@@ -8,6 +8,8 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -15,7 +17,6 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
const ( const (
@@ -1255,36 +1256,31 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
} }
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd, func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return return
} }
peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil {
return nil, err
}
userPeerIDMap := make(map[string]struct{}, len(peers)) userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers { for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{} userPeerIDMap[peer.ID] = struct{}{}
} }
for _, gid := range groupsToAdd { for _, gid := range groupsToAdd {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) group, ok := accountGroups[gid]
if err != nil { if !ok {
return nil, err return nil, errors.New("group not found")
} }
addUserPeersToGroup(userPeerIDMap, group) addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group) groupsToUpdate = append(groupsToUpdate, group)
} }
for _, gid := range groupsToRemove { for _, gid := range groupsToRemove {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) group, ok := accountGroups[gid]
if err != nil { if !ok {
return nil, err return nil, errors.New("group not found")
} }
removeUserPeersFromGroup(userPeerIDMap, group) removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group) groupsToUpdate = append(groupsToUpdate, group)