mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-20 23:59:55 +00:00
refactor getAccountFromToken
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -1625,13 +1625,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// redeemInvite checks whether user has been invited and redeems the invite
|
// redeemInvite checks whether user has been invited and redeems the invite
|
||||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error {
|
||||||
// only possible with the enabled IdP manager
|
// only possible with the enabled IdP manager
|
||||||
if am.idpManager == nil {
|
if am.idpManager == nil {
|
||||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
unlock()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1751,94 +1759,112 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||||
if err != nil {
|
|
||||||
return "", "", err
|
|
||||||
}
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
|
||||||
alreadyUnlocked := false
|
|
||||||
defer func() {
|
|
||||||
if !alreadyUnlocked {
|
|
||||||
unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, newAcc.Id)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
user := account.Users[claims.UserId]
|
user, err := am.Store.GetUserByUserID(ctx, claims.UserId)
|
||||||
if user == nil {
|
if err != nil {
|
||||||
// this is not really possible because we got an account by user ID
|
// this is not really possible because we got an account by user ID
|
||||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !user.IsServiceUser && claims.Invited {
|
if !user.IsServiceUser && claims.Invited {
|
||||||
err = am.redeemInvite(ctx, account, claims.UserId)
|
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if account.Settings.JWTGroupsEnabled {
|
if err = am.syncJWTGroups(ctx, claims, accountID); err != nil {
|
||||||
if account.Settings.JWTGroupsClaimName == "" {
|
return "", "", err
|
||||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
}
|
||||||
|
|
||||||
return account.Id, user.Id, nil
|
return accountID, user.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
|
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error {
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !settings.JWTGroupsEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if settings.JWTGroupsClaimName == "" {
|
||||||
|
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||||
|
|
||||||
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
|
defer unlock()
|
||||||
|
|
||||||
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := account.FindUser(claims.UserId)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
oldGroups := make([]string, len(user.AutoGroups))
|
||||||
|
copy(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
// Update the account if group membership changes
|
||||||
|
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||||
|
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||||
|
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||||
|
|
||||||
|
if settings.GroupsPropagationEnabled {
|
||||||
|
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||||
|
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||||
|
account.Network.IncSerial()
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtGroupsNames := extractJWTGroups(ctx, account.Settings.JWTGroupsClaimName, claims)
|
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
oldGroups := make([]string, len(user.AutoGroups))
|
// Propagate changes to peers if group propagation is enabled
|
||||||
copy(oldGroups, user.AutoGroups)
|
if settings.GroupsPropagationEnabled {
|
||||||
// if groups were added or modified, save the account
|
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
am.updateAccountPeers(ctx, account)
|
||||||
if account.Settings.GroupsPropagationEnabled {
|
}
|
||||||
if user, err := account.FindUser(claims.UserId); err == nil {
|
|
||||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
for _, g := range addNewGroups {
|
||||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
if group := account.GetGroup(g); group != nil {
|
||||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
map[string]any{
|
||||||
account.Network.IncSerial()
|
"group": group.Name,
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
"group_id": group.ID,
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
"is_service_user": user.IsServiceUser,
|
||||||
} else {
|
"user_name": user.ServiceUserName})
|
||||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
}
|
||||||
am.updateAccountPeers(ctx, account)
|
}
|
||||||
unlock()
|
|
||||||
alreadyUnlocked = true
|
for _, g := range removeOldGroups {
|
||||||
for _, g := range addNewGroups {
|
if group := account.GetGroup(g); group != nil {
|
||||||
if group := account.GetGroup(g); group != nil {
|
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
map[string]any{
|
||||||
map[string]any{
|
"group": group.Name,
|
||||||
"group": group.Name,
|
"group_id": group.ID,
|
||||||
"group_id": group.ID,
|
"is_service_user": user.IsServiceUser,
|
||||||
"is_service_user": user.IsServiceUser,
|
"user_name": user.ServiceUserName})
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, g := range removeOldGroups {
|
|
||||||
if group := account.GetGroup(g); group != nil {
|
|
||||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
|
||||||
map[string]any{
|
|
||||||
"group": group.Name,
|
|
||||||
"group_id": group.ID,
|
|
||||||
"is_service_user": user.IsServiceUser,
|
|
||||||
"user_name": user.ServiceUserName})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return account.Id, user.Id, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||||
@@ -1858,27 +1884,31 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
|||||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||||
//
|
//
|
||||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||||
if claims.UserId == "" {
|
if claims.UserId == "" {
|
||||||
return nil, fmt.Errorf("user ID is empty")
|
return "", fmt.Errorf("user ID is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if Account ID is part of the claims
|
// if Account ID is part of the claims
|
||||||
// it means that we've already classified the domain and user has an account
|
// it means that we've already classified the domain and user has an account
|
||||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return account.Id, nil
|
||||||
} else if claims.AccountId != "" {
|
} else if claims.AccountId != "" {
|
||||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
||||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||||
}
|
}
|
||||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
||||||
return accountFromID, nil
|
return accountFromID.Id, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1893,7 +1923,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
|||||||
// if NotFound we are good to continue, otherwise return error
|
// if NotFound we are good to continue, otherwise return error
|
||||||
e, ok := status.FromError(err)
|
e, ok := status.FromError(err)
|
||||||
if !ok || e.Type() != status.NotFound {
|
if !ok || e.Type() != status.NotFound {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1903,7 +1933,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
|||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||||
@@ -1914,22 +1944,27 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
|||||||
|
|
||||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
return account, nil
|
return account.Id, nil
|
||||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||||
if domainAccount != nil {
|
if domainAccount != nil {
|
||||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
||||||
defer unlockAccount()
|
defer unlockAccount()
|
||||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
|
||||||
|
account, err = am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return account.Id, nil
|
||||||
} else {
|
} else {
|
||||||
// other error
|
// other error
|
||||||
return nil, err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -645,8 +645,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
|||||||
testCase.inputClaims.AccountId = initAccount.Id
|
testCase.inputClaims.AccountId = initAccount.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
accountID, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
|
||||||
require.NoError(t, err, "support function failed")
|
require.NoError(t, err, "support function failed")
|
||||||
|
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
|
require.NoError(t, err, "get account by account id")
|
||||||
|
|
||||||
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
|
||||||
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
|
||||||
|
|
||||||
@@ -685,8 +689,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
|
require.NoError(t, err, "get account by account id")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -696,8 +704,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "save account failed")
|
require.NoError(t, err, "save account failed")
|
||||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
|
require.NoError(t, err, "get account by account id")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -708,8 +720,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
|||||||
require.NoError(t, err, "save account failed")
|
require.NoError(t, err, "save account failed")
|
||||||
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
|
||||||
|
|
||||||
account, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
|
||||||
require.NoError(t, err, "get account by token failed")
|
require.NoError(t, err, "get account by token failed")
|
||||||
|
|
||||||
|
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
|
||||||
|
require.NoError(t, err, "get account by account id")
|
||||||
|
|
||||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||||
|
|
||||||
groupsByNames := map[string]*group.Group{}
|
groupsByNames := map[string]*group.Group{}
|
||||||
|
|||||||
Reference in New Issue
Block a user