mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
update transaction logic
This commit is contained in:
@@ -20,6 +20,11 @@ import (
|
|||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
gocache "github.com/patrickmn/go-cache"
|
||||||
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
@@ -36,10 +41,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
|
||||||
"github.com/rs/xid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/exp/maps"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -846,17 +847,7 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
|
|||||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||||
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
// Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups,
|
||||||
// newly groups to create and an error if any occurred.
|
// newly groups to create and an error if any occurred.
|
||||||
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID, accountID string, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return false, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
existedGroupsByName := make(map[string]*nbgroup.Group)
|
existedGroupsByName := make(map[string]*nbgroup.Group)
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
existedGroupsByName[group.Name] = group
|
existedGroupsByName[group.Name] = group
|
||||||
@@ -880,7 +871,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(ctx context.Context, userID
|
|||||||
if !exists {
|
if !exists {
|
||||||
group = &nbgroup.Group{
|
group = &nbgroup.Group{
|
||||||
ID: xid.New().String(),
|
ID: xid.New().String(),
|
||||||
AccountID: accountID,
|
AccountID: user.AccountID,
|
||||||
Name: name,
|
Name: name,
|
||||||
Issued: nbgroup.GroupIssuedJWT,
|
Issued: nbgroup.GroupIssuedJWT,
|
||||||
}
|
}
|
||||||
@@ -1836,16 +1827,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
|
|
||||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
hasChanges, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, claims.UserId, accountID, jwtGroupsNames)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// skip update if no changes
|
|
||||||
if !hasChanges {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer func() {
|
defer func() {
|
||||||
if unlockPeer != nil {
|
if unlockPeer != nil {
|
||||||
@@ -1853,19 +1834,39 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
var addNewGroups []string
|
||||||
return fmt.Errorf("error saving groups: %w", err)
|
var removeOldGroups []string
|
||||||
}
|
var hasChanges bool
|
||||||
|
var user *User
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error getting user: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
|
|
||||||
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
|
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
user, err = am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(ctx, user, groups, jwtGroupsNames)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting JWT groups changes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasChanges = changed
|
||||||
|
// skip update if no changes
|
||||||
|
if !changed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||||
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addNewGroups = difference(updatedAutoGroups, user.AutoGroups)
|
||||||
|
removeOldGroups = difference(user.AutoGroups, updatedAutoGroups)
|
||||||
|
|
||||||
user.AutoGroups = updatedAutoGroups
|
user.AutoGroups = updatedAutoGroups
|
||||||
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
||||||
return fmt.Errorf("error saving user: %w", err)
|
return fmt.Errorf("error saving user: %w", err)
|
||||||
@@ -1873,7 +1874,22 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
|
|
||||||
// Propagate changes to peers if group propagation is enabled
|
// Propagate changes to peers if group propagation is enabled
|
||||||
if settings.GroupsPropagationEnabled {
|
if settings.GroupsPropagationEnabled {
|
||||||
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups)
|
groups, err = transaction.GetAccountGroups(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting account groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupsMap := make(map[string]*nbgroup.Group, len(groups))
|
||||||
|
for _, group := range groups {
|
||||||
|
groupsMap[group.ID] = group
|
||||||
|
}
|
||||||
|
|
||||||
|
peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting user peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updatedGroups, err := am.updateUserPeersInGroups(ctx, groupsMap, peers, addNewGroups, removeOldGroups)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||||
}
|
}
|
||||||
@@ -1895,6 +1911,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !hasChanges {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1185,3 +1185,37 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 2, setupKey.UsedTimes)
|
assert.Equal(t, 2, setupKey.UsedTimes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSqlite_CreateAndGetObjcetInTransaction(t *testing.T) {
|
||||||
|
store, cleanup, err := NewSqliteTestStore(context.Background(), t.TempDir(), "testdata/extended-store.sqlite")
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
group := &nbgroup.Group{
|
||||||
|
ID: "group-id",
|
||||||
|
AccountID: "account-id",
|
||||||
|
Name: "group-name",
|
||||||
|
Issued: "api",
|
||||||
|
Peers: nil,
|
||||||
|
}
|
||||||
|
store.ExecuteInTransaction(context.Background(), func(transaction Store) error {
|
||||||
|
err := transaction.SaveGroup(context.Background(), LockingStrengthUpdate, group)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to save group")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("failed to get group")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("group: %v", group)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
@@ -15,7 +17,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -1255,36 +1256,31 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||||
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd,
|
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd,
|
||||||
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
|
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
|
||||||
|
|
||||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userPeerIDMap := make(map[string]struct{}, len(peers))
|
userPeerIDMap := make(map[string]struct{}, len(peers))
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
userPeerIDMap[peer.ID] = struct{}{}
|
userPeerIDMap[peer.ID] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range groupsToAdd {
|
for _, gid := range groupsToAdd {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
group, ok := accountGroups[gid]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, errors.New("group not found")
|
||||||
}
|
}
|
||||||
addUserPeersToGroup(userPeerIDMap, group)
|
addUserPeersToGroup(userPeerIDMap, group)
|
||||||
groupsToUpdate = append(groupsToUpdate, group)
|
groupsToUpdate = append(groupsToUpdate, group)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, gid := range groupsToRemove {
|
for _, gid := range groupsToRemove {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
group, ok := accountGroups[gid]
|
||||||
if err != nil {
|
if !ok {
|
||||||
return nil, err
|
return nil, errors.New("group not found")
|
||||||
}
|
}
|
||||||
removeUserPeersFromGroup(userPeerIDMap, group)
|
removeUserPeersFromGroup(userPeerIDMap, group)
|
||||||
groupsToUpdate = append(groupsToUpdate, group)
|
groupsToUpdate = append(groupsToUpdate, group)
|
||||||
|
|||||||
Reference in New Issue
Block a user