From 4ad14cb46b3e3f0b3ebac54e568bff82e8d50c2d Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 11 Oct 2023 17:09:30 +0300 Subject: [PATCH] Add Pagination for IdP Users Fetch (#1210) * Retrieve all workspace users via pagination, excluding custom user attributes * Retrieve all authentik users via pagination * Retrieve all Azure AD users via pagination * Simplify user data appending operation Reduced unnecessary iteration and used an efficient way to append all users to 'indexedUsers' * Fix ineffectual assignment to reqURL * Retrieve all Okta users via pagination * Add missing GetAccount metrics * Refactor * minimize memory allocation Refactored the memory allocation for the 'users' slice in the Okta IDP code. Previously, the slice was only initialized but not given a size. Now the size of userList is utilized to optimize memory allocation, reducing potential slice resizing and memory re-allocation costs while appending users. * Add logging for entries received from IdP management Added informative and debug logging statements in account.go file. Logging has been added to identify the number of entries received from Identity Provider (IdP) management. This will aid in tracking and debugging any potential data ingestion issues. --- management/server/account.go | 2 + management/server/idp/authentik.go | 84 ++++++++++++----------- management/server/idp/azure.go | 84 ++++++++++++++--------- management/server/idp/google_workspace.go | 78 ++++++++++++++------- management/server/idp/okta.go | 65 +++++++++++------- 5 files changed, 191 insertions(+), 122 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 1f8fc497b..d35ad2566 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -946,6 +946,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error { if err != nil { return err } + log.Infof("%d entries received from IdP management", len(userData)) // If the Identity Provider does not support writing AppMetadata, // in cases like this, we expect it to return all users in an "unset" field. @@ -1045,6 +1046,7 @@ func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interf if err != nil { return nil, err } + log.Debugf("%d entries received from IdP management", len(userData)) dataMap := make(map[string]*idp.UserData, len(userData)) for _, datum := range userData { diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index ca995b299..4bbf09404 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -251,34 +251,18 @@ func (am *AuthentikManager) GetUserDataByID(userID string, appMetadata AppMetada // GetAccount returns all the users for a given profile. func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { - ctx, err := am.authenticationContext() + users, err := am.getAllUsers() if err != nil { return nil, err } - userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute() - if err != nil { - return nil, err - } - defer resp.Body.Close() - if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetAccount() } - if resp.StatusCode != http.StatusOK { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode) - } - - users := make([]*UserData, 0) - for _, user := range userList.Results { - userData := parseAuthentikUser(user) - userData.AppMetadata.WTAccountID = accountID - - users = append(users, userData) + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user } return users, nil @@ -287,37 +271,59 @@ func (am *AuthentikManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (am *AuthentikManager) GetAllAccounts() (map[string][]*UserData, error) { - ctx, err := am.authenticationContext() + users, err := am.getAllUsers() if err != nil { return nil, err } - userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Execute() - if err != nil { - return nil, err - } - defer resp.Body.Close() + indexedUsers := make(map[string][]*UserData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) if am.appMetrics != nil { am.appMetrics.IDPMetrics().CountGetAllAccounts() } - if resp.StatusCode != http.StatusOK { - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) - } - - indexedUsers := make(map[string][]*UserData) - for _, user := range userList.Results { - userData := parseAuthentikUser(user) - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) - } - return indexedUsers, nil } +// getAllUsers returns all users in a Authentik account. +func (am *AuthentikManager) getAllUsers() ([]*UserData, error) { + users := make([]*UserData, 0) + + page := int32(1) + for { + ctx, err := am.authenticationContext() + if err != nil { + return nil, err + } + + userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute() + if err != nil { + return nil, err + } + _ = resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) + } + + for _, user := range userList.Results { + users = append(users, parseAuthentikUser(user)) + } + + page = int32(userList.GetPagination().Next) + if userList.GetPagination().Next == 0 { + break + } + + } + + return users, nil +} + // CreateUser creates a new user in authentik Idp and sends an invitation. func (am *AuthentikManager) CreateUser(_, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index e4224c26d..706e4d330 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -266,10 +266,7 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) { // GetAccount returns all the users for a given profile. func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { - q := url.Values{} - q.Add("$select", profileFields) - - body, err := am.get("users", q) + users, err := am.getAllUsers() if err != nil { return nil, err } @@ -278,18 +275,9 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { am.appMetrics.IDPMetrics().CountGetAccount() } - var profiles struct{ Value []azureProfile } - err = am.helper.Unmarshal(body, &profiles) - if err != nil { - return nil, err - } - - users := make([]*UserData, 0) - for _, profile := range profiles.Value { - userData := profile.userData() - userData.AppMetadata.WTAccountID = accountID - - users = append(users, userData) + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user } return users, nil @@ -298,28 +286,16 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) { // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) { - q := url.Values{} - q.Add("$select", profileFields) - - body, err := am.get("users", q) - if err != nil { - return nil, err - } - - if am.appMetrics != nil { - am.appMetrics.IDPMetrics().CountGetAllAccounts() - } - - var profiles struct{ Value []azureProfile } - err = am.helper.Unmarshal(body, &profiles) + users, err := am.getAllUsers() if err != nil { return nil, err } indexedUsers := make(map[string][]*UserData) - for _, profile := range profiles.Value { - userData := profile.userData() - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) + + if am.appMetrics != nil { + am.appMetrics.IDPMetrics().CountGetAllAccounts() } return indexedUsers, nil @@ -373,6 +349,39 @@ func (am *AzureManager) DeleteUser(userID string) error { return nil } +// getAllUsers returns all users in an Azure AD account. +func (am *AzureManager) getAllUsers() ([]*UserData, error) { + users := make([]*UserData, 0) + + q := url.Values{} + q.Add("$select", profileFields) + q.Add("$top", "500") + + for nextLink := "users"; nextLink != ""; { + body, err := am.get(nextLink, q) + if err != nil { + return nil, err + } + + var profiles struct { + Value []azureProfile + NextLink string `json:"@odata.nextLink"` + } + err = am.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + for _, profile := range profiles.Value { + users = append(users, profile.userData()) + } + + nextLink = profiles.NextLink + } + + return users, nil +} + // get perform Get requests. func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { jwtToken, err := am.credentials.Authenticate() @@ -380,7 +389,14 @@ func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) { return nil, err } - reqURL := fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode()) + var reqURL string + if strings.HasPrefix(resource, "https") { + // Already an absolute URL for paging + reqURL = resource + } else { + reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode()) + } + req, err := http.NewRequest(http.MethodGet, reqURL, nil) if err != nil { return nil, err diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index ed2de9a42..896fb707b 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -96,7 +96,7 @@ func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ string, _ AppMetadata) // GetUserDataByID requests user data from Google Workspace via ID. func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { - user, err := gm.usersService.Get(userID).Projection("full").Do() + user, err := gm.usersService.Get(userID).Do() if err != nil { return nil, err } @@ -113,43 +113,69 @@ func (gm *GoogleWorkspaceManager) GetUserDataByID(userID string, appMetadata App // GetAccount returns all the users for a given profile. func (gm *GoogleWorkspaceManager) GetAccount(accountID string) ([]*UserData, error) { - usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do() - if err != nil { - return nil, err - } - - usersData := make([]*UserData, 0) - for _, user := range usersList.Users { - userData := parseGoogleWorkspaceUser(user) - userData.AppMetadata.WTAccountID = accountID - - usersData = append(usersData, userData) - } - - return usersData, nil -} - -// GetAllAccounts gets all registered accounts with corresponding user data. -// It returns a list of users indexed by accountID. -func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { - usersList, err := gm.usersService.List().Customer(gm.CustomerID).Projection("full").Do() + users, err := gm.getAllUsers() if err != nil { return nil, err } if gm.appMetrics != nil { - gm.appMetrics.IDPMetrics().CountGetAllAccounts() + gm.appMetrics.IDPMetrics().CountGetAccount() + } + + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user + } + + return users, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// It returns a list of users indexed by accountID. +func (gm *GoogleWorkspaceManager) GetAllAccounts() (map[string][]*UserData, error) { + users, err := gm.getAllUsers() + if err != nil { + return nil, err } indexedUsers := make(map[string][]*UserData) - for _, user := range usersList.Users { - userData := parseGoogleWorkspaceUser(user) - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) + + if gm.appMetrics != nil { + gm.appMetrics.IDPMetrics().CountGetAllAccounts() } return indexedUsers, nil } +// getAllUsers returns all users in a Google Workspace account filtered by customer ID. +func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) { + users := make([]*UserData, 0) + pageToken := "" + for { + call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500) + if pageToken != "" { + call.PageToken(pageToken) + } + + resp, err := call.Do() + if err != nil { + return nil, err + } + + for _, user := range resp.Users { + users = append(users, parseGoogleWorkspaceUser(user)) + } + + pageToken = resp.NextPageToken + if pageToken == "" { + break + } + } + + return users, nil +} + // CreateUser creates a new user in Google Workspace and sends an invitation. func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, error) { return nil, fmt.Errorf("method CreateUser not implemented") @@ -158,7 +184,7 @@ func (gm *GoogleWorkspaceManager) CreateUser(_, _, _, _ string) (*UserData, erro // GetUserByEmail searches users with a given email. // If no users have been found, this function returns an empty list. func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, error) { - user, err := gm.usersService.Get(email).Projection("full").Do() + user, err := gm.usersService.Get(email).Do() if err != nil { return nil, err } diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index 3e7b9357e..67341a26f 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -9,6 +9,7 @@ import ( "time" "github.com/okta/okta-sdk-golang/v2/okta" + "github.com/okta/okta-sdk-golang/v2/okta/query" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -160,7 +161,7 @@ func (om *OktaManager) GetUserByEmail(email string) ([]*UserData, error) { // GetAccount returns all the users for a given profile. func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { - users, resp, err := om.client.User.ListUsers(context.Background(), nil) + users, err := om.getAllUsers() if err != nil { return nil, err } @@ -169,39 +170,40 @@ func (om *OktaManager) GetAccount(accountID string) ([]*UserData, error) { om.appMetrics.IDPMetrics().CountGetAccount() } - if resp.StatusCode != http.StatusOK { - if om.appMetrics != nil { - om.appMetrics.IDPMetrics().CountRequestStatusError() - } - return nil, fmt.Errorf("unable to get account, statusCode %d", resp.StatusCode) + for index, user := range users { + user.AppMetadata.WTAccountID = accountID + users[index] = user } - list := make([]*UserData, 0) - for _, user := range users { - userData, err := parseOktaUser(user) - if err != nil { - return nil, err - } - userData.AppMetadata.WTAccountID = accountID - - list = append(list, userData) - } - - return list, nil + return users, nil } // GetAllAccounts gets all registered accounts with corresponding user data. // It returns a list of users indexed by accountID. func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { - users, resp, err := om.client.User.ListUsers(context.Background(), nil) + users, err := om.getAllUsers() if err != nil { return nil, err } + indexedUsers := make(map[string][]*UserData) + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...) + if om.appMetrics != nil { om.appMetrics.IDPMetrics().CountGetAllAccounts() } + return indexedUsers, nil +} + +// getAllUsers returns all users in an Okta account. +func (om *OktaManager) getAllUsers() ([]*UserData, error) { + qp := query.NewQueryParams(query.WithLimit(200)) + userList, resp, err := om.client.User.ListUsers(context.Background(), qp) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { if om.appMetrics != nil { om.appMetrics.IDPMetrics().CountRequestStatusError() @@ -209,17 +211,34 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) { return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) } - indexedUsers := make(map[string][]*UserData) - for _, user := range users { + for resp.HasNextPage() { + paginatedUsers := make([]*okta.User, 0) + resp, err = resp.Next(context.Background(), &paginatedUsers) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + if om.appMetrics != nil { + om.appMetrics.IDPMetrics().CountRequestStatusError() + } + return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode) + } + + userList = append(userList, paginatedUsers...) + } + + users := make([]*UserData, 0, len(userList)) + for _, user := range userList { userData, err := parseOktaUser(user) if err != nil { return nil, err } - indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + users = append(users, userData) } - return indexedUsers, nil + return users, nil } // UpdateUserAppMetadata updates user app metadata based on userID and metadata map.