mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-27 20:56:44 +00:00
Merge branch 'main' into chore/benchmark-with-large-runner
This commit is contained in:
@@ -2,11 +2,8 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
b64 "encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -24,14 +21,13 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -77,13 +73,10 @@ type AccountManager interface {
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error)
|
||||
AccountExists(ctx context.Context, accountID string) (bool, error)
|
||||
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetPATInfo(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, string, string, error)
|
||||
GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*types.User, error)
|
||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
|
||||
@@ -150,6 +143,7 @@ type AccountManager interface {
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -954,11 +948,11 @@ func (am *DefaultAccountManager) removeUserFromCache(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
// updateAccountDomainAttributesIfNotUpToDate updates the account domain attributes if they are not up to date and then, saves the account changes
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims,
|
||||
func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx context.Context, accountID string, userAuth nbcontext.UserAuth,
|
||||
primaryDomain bool,
|
||||
) error {
|
||||
if claims.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
|
||||
if userAuth.Domain == "" {
|
||||
log.WithContext(ctx).Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -971,11 +965,11 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
return err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
|
||||
return nil
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting user: %v", err)
|
||||
return err
|
||||
@@ -984,13 +978,13 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
|
||||
newDomain := accountDomain
|
||||
newCategoty := domainCategory
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
lowerDomain := strings.ToLower(userAuth.Domain)
|
||||
if accountDomain != lowerDomain && user.HasAdminPower() {
|
||||
newDomain = lowerDomain
|
||||
}
|
||||
|
||||
if accountDomain == lowerDomain {
|
||||
newCategoty = claims.DomainCategory
|
||||
newCategoty = userAuth.DomainCategory
|
||||
}
|
||||
|
||||
return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain)
|
||||
@@ -1006,16 +1000,16 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
ctx context.Context,
|
||||
userAccountID string,
|
||||
domainAccountID string,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
userAuth nbcontext.UserAuth,
|
||||
) error {
|
||||
primaryDomain := domainAccountID == "" || userAccountID == domainAccountID
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, claims, primaryDomain)
|
||||
err := am.updateAccountDomainAttributesIfNotUpToDate(ctx, userAccountID, userAuth, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, userAccountID)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, userAccountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1025,20 +1019,20 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
|
||||
// addNewPrivateAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
|
||||
// otherwise it will create a new account and make it primary account for the domain.
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
if claims.UserId == "" {
|
||||
func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
|
||||
lowerDomain := strings.ToLower(claims.Domain)
|
||||
lowerDomain := strings.ToLower(userAuth.Domain)
|
||||
|
||||
newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain)
|
||||
newAccount, err := am.newAccount(ctx, userAuth.UserId, lowerDomain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
newAccount.Domain = lowerDomain
|
||||
newAccount.DomainCategory = claims.DomainCategory
|
||||
newAccount.DomainCategory = userAuth.DomainCategory
|
||||
newAccount.IsDomainPrimaryAccount = true
|
||||
|
||||
err = am.Store.SaveAccount(ctx, newAccount)
|
||||
@@ -1046,33 +1040,33 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, newAccount.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, newAccount.Id, activity.UserJoined, nil)
|
||||
|
||||
return newAccount.Id, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, userAuth nbcontext.UserAuth) (string, error) {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
|
||||
newUser := types.NewRegularUser(claims.UserId)
|
||||
newUser := types.NewRegularUser(userAuth.UserId)
|
||||
newUser.AccountID = domainAccountID
|
||||
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, domainAccountID)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userAuth.UserId, domainAccountID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, claims.UserId, claims.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
am.StoreEvent(ctx, userAuth.UserId, userAuth.UserId, domainAccountID, activity.UserJoined, nil)
|
||||
|
||||
return domainAccountID, nil
|
||||
}
|
||||
@@ -1112,76 +1106,11 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
return am.Store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
|
||||
func (am *DefaultAccountManager) GetPATInfo(ctx context.Context, token string) (user *types.User, pat *types.PersonalAccessToken, domain string, category string, err error) {
|
||||
user, pat, err = am.extractPATFromToken(ctx, token)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID)
|
||||
if err != nil {
|
||||
return nil, nil, "", "", err
|
||||
}
|
||||
|
||||
return user, pat, domain, category, nil
|
||||
}
|
||||
|
||||
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
|
||||
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*types.User, *types.PersonalAccessToken, error) {
|
||||
if len(token) != types.PATLength {
|
||||
return nil, nil, fmt.Errorf("token has incorrect length")
|
||||
}
|
||||
|
||||
prefix := token[:len(types.PATPrefix)]
|
||||
if prefix != types.PATPrefix {
|
||||
return nil, nil, fmt.Errorf("token has wrong prefix")
|
||||
}
|
||||
secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength]
|
||||
encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength]
|
||||
|
||||
verificationChecksum, err := base62.Decode(encodedChecksum)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
|
||||
}
|
||||
|
||||
secretChecksum := crc32.ChecksumIEEE([]byte(secret))
|
||||
if secretChecksum != verificationChecksum {
|
||||
return nil, nil, fmt.Errorf("token checksum does not match")
|
||||
}
|
||||
|
||||
hashedToken := sha256.Sum256([]byte(token))
|
||||
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
|
||||
|
||||
var user *types.User
|
||||
var pat *types.PersonalAccessToken
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
@@ -1196,58 +1125,56 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
if userAuth.UserId == "" {
|
||||
return "", "", errors.New(emptyUserID)
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
// We override incoming domain claims to group users under a single account.
|
||||
claims.Domain = am.singleAccountModeDomain
|
||||
claims.DomainCategory = types.PrivateCategory
|
||||
userAuth.Domain = am.singleAccountModeDomain
|
||||
userAuth.DomainCategory = types.PrivateCategory
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
}
|
||||
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
|
||||
}
|
||||
|
||||
if userAuth.IsChild {
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", claims.UserId, accountID)
|
||||
return "", "", status.Errorf(status.PermissionDenied, "user %s is not part of the account %s", userAuth.UserId, accountID)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
if !user.IsServiceUser && userAuth.Invited {
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
if err = am.syncJWTGroups(ctx, accountID, claims); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
|
||||
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
|
||||
if isToken, ok := claim.(bool); ok && isToken {
|
||||
return nil
|
||||
}
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
return nil
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1261,9 +1188,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return nil
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAuth.AccountId)
|
||||
defer func() {
|
||||
if unlockAccount != nil {
|
||||
unlockAccount()
|
||||
@@ -1275,17 +1200,17 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
var hasChanges bool
|
||||
var user *types.User
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, jwtGroupsNames)
|
||||
changed, updatedAutoGroups, newGroupsToCreate, err := am.getJWTGroupsChanges(user, groups, userAuth.Groups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting JWT groups changes: %w", err)
|
||||
}
|
||||
@@ -1310,7 +1235,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
|
||||
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
@@ -1320,7 +1245,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId)
|
||||
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting user peers: %w", err)
|
||||
}
|
||||
@@ -1334,7 +1259,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -1352,45 +1277,45 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||
} else {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||
}
|
||||
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupAddedToUser, meta)
|
||||
am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupAddedToUser, meta)
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g)
|
||||
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
|
||||
} else {
|
||||
meta := map[string]any{
|
||||
"group": group.Name, "group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName,
|
||||
}
|
||||
am.StoreEvent(ctx, user.Id, user.Id, accountID, activity.GroupRemovedFromUser, meta)
|
||||
am.StoreEvent(ctx, user.Id, user.Id, userAuth.AccountId, activity.GroupRemovedFromUser, meta)
|
||||
}
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
|
||||
removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, removeOldGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
|
||||
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, userAuth.AccountId, addNewGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.UpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1415,24 +1340,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
//
|
||||
// UserAuth IsChild -> checks that account exists
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
userAuth.UserId, userAuth.AccountId, userAuth.Domain, userAuth.DomainCategory)
|
||||
|
||||
if claims.UserId == "" {
|
||||
if userAuth.UserId == "" {
|
||||
return "", errors.New(emptyUserID)
|
||||
}
|
||||
|
||||
if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain)
|
||||
if userAuth.IsChild {
|
||||
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if err != nil || !exists {
|
||||
return "", err
|
||||
}
|
||||
return userAuth.AccountId, nil
|
||||
}
|
||||
|
||||
if claims.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, claims)
|
||||
if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
|
||||
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
|
||||
}
|
||||
|
||||
if userAuth.AccountId != "" {
|
||||
return am.handlePrivateAccountWithIDFromClaim(ctx, userAuth)
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, claims.Domain)
|
||||
domainAccountID, cancel, err := am.getPrivateDomainWithGlobalLock(ctx, userAuth.Domain)
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
@@ -1440,14 +1375,14 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != "" {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, claims); err != nil {
|
||||
if err = am.handleExistingUserAccount(ctx, userAccountID, domainAccountID, userAuth); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1455,10 +1390,10 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
||||
}
|
||||
|
||||
if domainAccountID != "" {
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, claims)
|
||||
return am.addNewUserToDomainAccount(ctx, domainAccountID, userAuth)
|
||||
}
|
||||
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, claims)
|
||||
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
|
||||
}
|
||||
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain)
|
||||
@@ -1486,40 +1421,40 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
|
||||
return domainAccountID, cancel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
|
||||
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
if userAccountID != userAuth.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
|
||||
}
|
||||
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId)
|
||||
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainIsUpToDate(accountDomain, domainCategory, claims) {
|
||||
return claims.AccountId, nil
|
||||
if domainIsUpToDate(accountDomain, domainCategory, userAuth) {
|
||||
return userAuth.AccountId, nil
|
||||
}
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain)
|
||||
if handleNotFound(err) != nil {
|
||||
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, claims.AccountId, domainAccountID, claims)
|
||||
err = am.handleExistingUserAccount(ctx, userAuth.AccountId, domainAccountID, userAuth)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return claims.AccountId, nil
|
||||
return userAuth.AccountId, nil
|
||||
}
|
||||
|
||||
func handleNotFound(err error) error {
|
||||
@@ -1534,8 +1469,8 @@ func handleNotFound(err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool {
|
||||
return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain
|
||||
func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.UserAuth) bool {
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
@@ -1618,34 +1553,6 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
return am.dnsDomain
|
||||
}
|
||||
|
||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and set the list of groups with access permissions.
|
||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if settings != nil && settings.JWTGroupsEnabled {
|
||||
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
@@ -1804,39 +1711,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
return acc
|
||||
}
|
||||
|
||||
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[claimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
|
||||
}
|
||||
|
||||
return userJWTGroups
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
for _, allowedGroup := range allowedGroups {
|
||||
if userGroup == allowedGroup {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// separateGroups separates user's auto groups into non-JWT and JWT groups.
|
||||
// Returns the list of standard auto groups and a map of JWT auto groups,
|
||||
// where the keys are the group names and the values are the group IDs.
|
||||
|
||||
Reference in New Issue
Block a user