diff --git a/management/server/account.go b/management/server/account.go index 984139a12..cb32378ed 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,16 +11,13 @@ import ( "net/netip" "reflect" "regexp" + "slices" "strings" "sync" "time" "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" @@ -35,6 +32,10 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -758,8 +759,13 @@ func (a *Account) GetPeer(peerID string) *nbpeer.Peer { return a.Peers[peerID] } -// SetJWTGroups to account and to user autoassigned groups +// SetJWTGroups updates the user's auto groups by synchronizing JWT groups. +// Returns true if there are changes in the JWT group membership. func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { + if len(groupsNames) == 0 { + return false + } + user, ok := a.Users[userID] if !ok { return false @@ -770,23 +776,19 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { existedGroupsByName[group.Name] = group } - // remove JWT groups from the autogroups, to sync them again - removed := 0 - jwtAutoGroups := make(map[string]struct{}) - for i, id := range user.AutoGroups { - if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT { - jwtAutoGroups[group.Name] = struct{}{} - user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) - removed++ - } + newAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, a.Groups) + groupsToAdd := difference(groupsNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupsNames) + + // If no groups are added or removed, we should not sync account + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return false } - // create JWT groups if they doesn't exist - // and all of them to the autogroups var modified bool - for _, name := range groupsNames { - group, ok := existedGroupsByName[name] - if !ok { + for _, name := range groupsToAdd { + group, exists := existedGroupsByName[name] + if !exists { group = &nbgroup.Group{ ID: xid.New().String(), Name: name, @@ -794,20 +796,20 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { } a.Groups[group.ID] = group } - // only JWT groups will be synced if group.Issued == nbgroup.GroupIssuedJWT { - user.AutoGroups = append(user.AutoGroups, group.ID) - if _, ok := jwtAutoGroups[name]; !ok { - modified = true - } - delete(jwtAutoGroups, name) + newAutoGroups = append(newAutoGroups, group.ID) + modified = true } } - // if not empty it means we removed some groups - if len(jwtAutoGroups) > 0 { + for name, id := range jwtGroupsMap { + if !slices.Contains(groupsToRemove, name) { + newAutoGroups = append(newAutoGroups, id) + continue + } modified = true } + user.AutoGroups = newAutoGroups return modified } @@ -1714,6 +1716,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if err := am.Store.SaveAccount(account); err != nil { log.Errorf("failed to save account: %v", err) } else { + log.Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) am.updateAccountPeers(account) unlock() alreadyUnlocked = true @@ -2064,3 +2067,22 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { } return false } + +// separateGroups separates user's auto groups into non-JWT and JWT groups. +// Returns the list of standard auto groups and a map of JWT auto groups, +// where the keys are the group names and the values are the group IDs. +func separateGroups(autoGroups []string, allGroups map[string]*nbgroup.Group) ([]string, map[string]string) { + newAutoGroups := make([]string, 0) + jwtAutoGroups := make(map[string]string) // map of group name to group ID + + for _, id := range autoGroups { + if group, ok := allGroups[id]; ok { + if group.Issued == nbgroup.GroupIssuedJWT { + jwtAutoGroups[group.Name] = id + } else { + newAutoGroups = append(newAutoGroups, id) + } + } + } + return newAutoGroups, jwtAutoGroups +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 38c9fabbc..a5ba8fdcf 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2175,17 +2175,33 @@ func TestAccount_SetJWTGroups(t *testing.T) { }, } - t.Run("api group already exists", func(t *testing.T) { - updated := account.SetJWTGroups("user1", []string{"group1"}) + t.Run("empty jwt groups", func(t *testing.T) { + updated := account.SetJWTGroups("user1", []string{}) assert.False(t, updated, "account should not be updated") assert.Empty(t, account.Users["user1"].AutoGroups, "auto groups must be empty") }) + t.Run("jwt match existing api group", func(t *testing.T) { + updated := account.SetJWTGroups("user1", []string{"group1"}) + assert.False(t, updated, "account should not be updated") + assert.Equal(t, 0, len(account.Users["user1"].AutoGroups)) + assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + }) + + t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { + account.Users["user1"].AutoGroups = []string{"group1"} + + updated := account.SetJWTGroups("user1", []string{"group1"}) + assert.False(t, updated, "account should not be updated") + assert.Equal(t, 1, len(account.Users["user1"].AutoGroups)) + assert.Equal(t, account.Groups["group1"].Issued, group.GroupIssuedAPI, "group should be api issued") + }) + t.Run("add jwt group", func(t *testing.T) { updated := account.SetJWTGroups("user1", []string{"group1", "group2"}) assert.True(t, updated, "account should be updated") assert.Len(t, account.Groups, 2, "new group should be added") - assert.Len(t, account.Users["user1"].AutoGroups, 1, "new group should be added") + assert.Len(t, account.Users["user1"].AutoGroups, 2, "new group should be added") assert.Contains(t, account.Groups, account.Users["user1"].AutoGroups[0], "groups must contain group2 from user groups") })