mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management] Remove redundant get account calls in GetAccountFromToken (#2615)
* refactor access control middleware and user access by JWT groups Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor jwt groups extractor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor handlers to get account when necessary Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountFromToken Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountWithAuthorizationClaims Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix merge Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * revert handles change Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove GetUserByID from account manager Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor getAccountWithAuthorizationClaims to return account id Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor handlers to use GetAccountIDFromToken Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * remove locks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add GetGroupByName from store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add GetGroupByID from store and refactor Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor retrieval of policy and posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor user permissions and retrieves PAT Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor route, setupkey, nameserver and dns to get record(s) from store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * Refactor store Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix lint Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix add missing policy source posture checks Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add store lock Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * fix tests Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> * add get account Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com> --------- Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -20,11 +20,6 @@ import (
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
@@ -41,6 +36,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
|
||||
|
||||
type AccountManager interface {
|
||||
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
|
||||
GetAccount(ctx context.Context, accountID string) (*Account, error)
|
||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
|
||||
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
@@ -75,12 +75,14 @@ type AccountManager interface {
|
||||
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
|
||||
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
|
||||
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
|
||||
GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
|
||||
GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
|
||||
GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
|
||||
GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
|
||||
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
|
||||
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
|
||||
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
|
||||
DeleteAccount(ctx context.Context, accountID, userID string) error
|
||||
MarkPATUsed(ctx context.Context, tokenID string) error
|
||||
GetUserByID(ctx context.Context, id string) (*User, error)
|
||||
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
@@ -107,7 +109,7 @@ type AccountManager interface {
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
|
||||
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
|
||||
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
|
||||
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
|
||||
@@ -145,6 +147,7 @@ type AccountManager interface {
|
||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -268,6 +271,11 @@ type AccountNetwork struct {
|
||||
Network *Network `gorm:"embedded;embeddedPrefix:network_"`
|
||||
}
|
||||
|
||||
// AccountDNSSettings used in gorm to only load dns settings and not whole account
|
||||
type AccountDNSSettings struct {
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
}
|
||||
|
||||
type UserPermissions struct {
|
||||
DashboardView string `json:"dashboard_view"`
|
||||
}
|
||||
@@ -1252,25 +1260,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
|
||||
// userID doesn't have an account associated with it, one account is created
|
||||
// domain is used to create a new account if no account is found
|
||||
func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
|
||||
// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
|
||||
// If an accountID is provided, it checks if the account exists and returns it.
|
||||
// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
|
||||
// If the user doesn't have an account, it creates one using the provided domain.
|
||||
// Returns the account ID or an error if none is found or created.
|
||||
func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
|
||||
if accountID != "" {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
} else if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
|
||||
return "", err
|
||||
}
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if !exists {
|
||||
return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
|
||||
}
|
||||
return account, nil
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
|
||||
if userID != "" {
|
||||
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
|
||||
if err != nil {
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return account.Id, nil
|
||||
}
|
||||
|
||||
return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
|
||||
}
|
||||
|
||||
func isNil(i idp.Manager) bool {
|
||||
@@ -1613,13 +1633,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
||||
}
|
||||
|
||||
// redeemInvite checks whether user has been invited and redeems the invite
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
|
||||
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
|
||||
// only possible with the enabled IdP manager
|
||||
if am.idpManager == nil {
|
||||
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1678,6 +1703,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
|
||||
return am.Store.SaveAccount(ctx, account)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountFromPAT returns Account and User associated with a personal access token
|
||||
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
|
||||
if len(token) != PATLength {
|
||||
@@ -1726,10 +1756,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
|
||||
return account, user, pat, nil
|
||||
}
|
||||
|
||||
// GetAccountFromToken returns an account associated with this token
|
||||
func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
|
||||
// GetAccountByID returns an account associated with this account ID.
|
||||
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
return am.Store.GetAccount(ctx, accountID)
|
||||
}
|
||||
|
||||
// GetAccountIDFromToken returns an account ID associated with this token.
|
||||
func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
if claims.UserId == "" {
|
||||
return nil, nil, fmt.Errorf("user ID is empty")
|
||||
return "", "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
@@ -1739,110 +1783,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
}
|
||||
|
||||
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
|
||||
alreadyUnlocked := false
|
||||
defer func() {
|
||||
if !alreadyUnlocked {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, newAcc.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
user := account.Users[claims.UserId]
|
||||
if user == nil {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
|
||||
if err != nil {
|
||||
// this is not really possible because we got an account by user ID
|
||||
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
|
||||
}
|
||||
|
||||
if !user.IsServiceUser && claims.Invited {
|
||||
err = am.redeemInvite(ctx, account, claims.UserId)
|
||||
err = am.redeemInvite(ctx, accountID, user.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
if account.Settings.JWTGroupsEnabled {
|
||||
if account.Settings.JWTGroupsClaimName == "" {
|
||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||
return account, user, nil
|
||||
}
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if slice, ok := claim.([]interface{}); ok {
|
||||
var groupsNames []string
|
||||
for _, item := range slice {
|
||||
if g, ok := item.(string); ok {
|
||||
groupsNames = append(groupsNames, g)
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
||||
}
|
||||
}
|
||||
|
||||
oldGroups := make([]string, len(user.AutoGroups))
|
||||
copy(oldGroups, user.AutoGroups)
|
||||
// if groups were added or modified, save the account
|
||||
if account.SetJWTGroups(claims.UserId, groupsNames) {
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
if user, err := account.FindUser(claims.UserId); err == nil {
|
||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||
account.Network.IncSerial()
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
} else {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
unlock()
|
||||
alreadyUnlocked = true
|
||||
for _, g := range addNewGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
for _, g := range removeOldGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if settings == nil || !settings.JWTGroupsEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if settings.JWTGroupsClaimName == "" {
|
||||
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: Remove GetAccount after refactoring account peer's update
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
oldGroups := make([]string, len(user.AutoGroups))
|
||||
copy(oldGroups, user.AutoGroups)
|
||||
|
||||
// Update the account if group membership changes
|
||||
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
|
||||
addNewGroups := difference(user.AutoGroups, oldGroups)
|
||||
removeOldGroups := difference(oldGroups, user.AutoGroups)
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
|
||||
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
|
||||
account.Network.IncSerial()
|
||||
}
|
||||
|
||||
if err := am.Store.SaveAccount(ctx, account); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save account: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
if group := account.GetGroup(g); group != nil {
|
||||
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": user.IsServiceUser,
|
||||
"user_name": user.ServiceUserName})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
|
||||
// if domain is of the PrivateCategory category, it will evaluate
|
||||
// if account is new, existing or if there is another account with the same domain
|
||||
//
|
||||
@@ -1859,26 +1904,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
|
||||
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
|
||||
//
|
||||
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
|
||||
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
|
||||
func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
|
||||
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
|
||||
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
|
||||
if claims.UserId == "" {
|
||||
return nil, fmt.Errorf("user ID is empty")
|
||||
return "", fmt.Errorf("user ID is empty")
|
||||
}
|
||||
|
||||
// if Account ID is part of the claims
|
||||
// it means that we've already classified the domain and user has an account
|
||||
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
|
||||
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
|
||||
} else if claims.AccountId != "" {
|
||||
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
if _, ok := accountFromID.Users[claims.UserId]; !ok {
|
||||
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
|
||||
if userAccountID != claims.AccountId {
|
||||
return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
|
||||
}
|
||||
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
|
||||
return accountFromID, nil
|
||||
|
||||
domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
|
||||
return userAccountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1888,48 +1941,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
|
||||
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
|
||||
|
||||
// We checked if the domain has a primary account already
|
||||
domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
|
||||
if err != nil {
|
||||
// if NotFound we are good to continue, otherwise return error
|
||||
e, ok := status.FromError(err)
|
||||
if !ok || e.Type() != status.NotFound {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
|
||||
if err == nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
|
||||
defer unlockAccount()
|
||||
account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
|
||||
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
|
||||
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
|
||||
// was previously unclassified or classified as public so N users that logged int that time, has they own account
|
||||
// and peers that shouldn't be lost.
|
||||
primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
|
||||
|
||||
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
primaryDomain := domainAccountID == "" || account.Id == domainAccountID
|
||||
if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account, nil
|
||||
|
||||
return account.Id, nil
|
||||
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
if domainAccount != nil {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
|
||||
var domainAccount *Account
|
||||
if domainAccountID != "" {
|
||||
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
|
||||
defer unlockAccount()
|
||||
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
|
||||
account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return account.Id, nil
|
||||
} else {
|
||||
// other error
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2022,26 +2080,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
|
||||
// group propagation and set the list of groups with access permissions.
|
||||
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
|
||||
account, _, err := am.GetAccountFromToken(ctx, claims)
|
||||
accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensures JWT group synchronization to the management is enabled before,
|
||||
// filtering access based on the allowed groups.
|
||||
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
|
||||
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if settings != nil && settings.JWTGroupsEnabled {
|
||||
if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
|
||||
userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
|
||||
|
||||
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
|
||||
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
|
||||
@@ -2111,6 +2164,19 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor
|
||||
return newLabel, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
@@ -2193,6 +2259,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
|
||||
return acc
|
||||
}
|
||||
|
||||
// extractJWTGroups extracts the group names from a JWT token's claims.
|
||||
func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
|
||||
userJWTGroups := make([]string, 0)
|
||||
|
||||
if claim, ok := claims.Raw[claimName]; ok {
|
||||
if claimGroups, ok := claim.([]interface{}); ok {
|
||||
for _, g := range claimGroups {
|
||||
if group, ok := g.(string); ok {
|
||||
userJWTGroups = append(userJWTGroups, group)
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName)
|
||||
}
|
||||
|
||||
return userJWTGroups
|
||||
}
|
||||
|
||||
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
|
||||
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
|
||||
for _, userGroup := range userGroups {
|
||||
|
||||
Reference in New Issue
Block a user