mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
refactor(idp): make NetBird single source of truth for authorization
Remove duplicate authorization data from Zitadel IdP. NetBird now stores all authorization data (account membership, invite status, roles) locally, while Zitadel only stores identity information (email, name, credentials). Changes: - Add PendingInvite field to User struct to track invite status locally - Simplify IdP Manager interface: remove metadata methods, add GetAllUsers - Update cache warming to match IdP users against NetBird DB - Remove addAccountIDToIDPAppMeta and all wt_* metadata writes - Delete legacy IdP managers (Auth0, Azure, Keycloak, Okta, Google Workspace, JumpCloud, Authentik, PocketId) - only Zitadel supported
This commit is contained in:
@@ -1,959 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Auth0Manager auth0 manager client instance
|
||||
type Auth0Manager struct {
|
||||
authIssuer string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// Auth0ClientConfig auth0 manager client configurations
|
||||
type Auth0ClientConfig struct {
|
||||
Audience string
|
||||
AuthIssuer string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// auth0JWTRequest payload struct to request a JWT Token
|
||||
type auth0JWTRequest struct {
|
||||
Audience string `json:"audience"`
|
||||
AuthIssuer string `json:"auth_issuer"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
GrantType string `json:"grant_type"`
|
||||
}
|
||||
|
||||
// Auth0Credentials auth0 authentication information
|
||||
type Auth0Credentials struct {
|
||||
clientConfig Auth0ClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// createUserRequest is a user create request
|
||||
type createUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
AppMeta AppMetadata `json:"app_metadata"`
|
||||
Connection string `json:"connection"`
|
||||
Password string `json:"password"`
|
||||
VerifyEmail bool `json:"verify_email"`
|
||||
}
|
||||
|
||||
// userExportJobRequest is a user export request struct
|
||||
type userExportJobRequest struct {
|
||||
Format string `json:"format"`
|
||||
Fields []map[string]string `json:"fields"`
|
||||
}
|
||||
|
||||
// userExportJobResponse is a user export response struct
|
||||
type userExportJobResponse struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Format string `json:"format"`
|
||||
Limit int `json:"limit"`
|
||||
Connection string `json:"connection"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userExportJobStatusResponse is a user export status response struct
|
||||
type userExportJobStatusResponse struct {
|
||||
Type string `json:"type"`
|
||||
Status string `json:"status"`
|
||||
ConnectionID string `json:"connection_id"`
|
||||
Format string `json:"format"`
|
||||
Limit int `json:"limit"`
|
||||
Location string `json:"location"`
|
||||
Connection string `json:"connection"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userVerificationJobRequest is a user verification request struct
|
||||
type userVerificationJobRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// auth0Profile represents an Auth0 user profile response
|
||||
type auth0Profile struct {
|
||||
AccountID string `json:"wt_account_id"`
|
||||
PendingInvite bool `json:"wt_pending_invite"`
|
||||
UserID string `json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
LastLogin string `json:"last_login"`
|
||||
}
|
||||
|
||||
// Connections represents a single Auth0 connection
|
||||
// https://auth0.com/docs/api/management/v2/connections/get-connections
|
||||
type Connection struct {
|
||||
Id string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
IsDomainConnection bool `json:"is_domain_connection"`
|
||||
Realms []string `json:"realms"`
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
Options ConnectionOptions `json:"options"`
|
||||
}
|
||||
|
||||
type ConnectionOptions struct {
|
||||
DomainAliases []string `json:"domain_aliases"`
|
||||
}
|
||||
|
||||
// NewAuth0Manager creates a new instance of the Auth0Manager
|
||||
func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) (*Auth0Manager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.AuthIssuer == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, AuthIssuer is missing")
|
||||
}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.Audience == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, Audience is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("auth0 IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &Auth0Manager{
|
||||
authIssuer: config.AuthIssuer,
|
||||
credentials: credentials,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from Auth0
|
||||
func (c *Auth0Credentials) jwtStillValid() bool {
|
||||
return !c.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(c.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token
|
||||
func (c *Auth0Credentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
var res *http.Response
|
||||
reqURL := c.clientConfig.AuthIssuer + "/oauth/token"
|
||||
|
||||
p, err := c.helper.Marshal(auth0JWTRequest(c.clientConfig))
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
|
||||
req, err := http.NewRequest("POST", reqURL, payload)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for idp manager")
|
||||
|
||||
res, err = c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if c.appMetrics != nil {
|
||||
c.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return res, fmt.Errorf("unable to get token, statusCode %d", res.StatusCode)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = c.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = json.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the Auth0 Management API
|
||||
func (c *Auth0Credentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
if c.appMetrics != nil {
|
||||
c.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// If jwtToken has an expires time and we have enough time to do a request return immediately
|
||||
if c.jwtStillValid() {
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
res, err := c.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing get jwt token response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
jwtToken, err := c.parseRequestJWTResponse(res.Body)
|
||||
if err != nil {
|
||||
return c.jwtToken, err
|
||||
}
|
||||
|
||||
c.jwtToken = jwtToken
|
||||
|
||||
return c.jwtToken, nil
|
||||
}
|
||||
|
||||
func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
|
||||
u, err := url.Parse(authIssuer + "/api/v2/users")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
q.Set("search_engine", "v3")
|
||||
q.Set("per_page", strconv.Itoa(perPage))
|
||||
q.Set("q", "app_metadata.wt_account_id:"+accountID)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), q, nil
|
||||
}
|
||||
|
||||
func requestByUserIDURL(authIssuer, userID string) string {
|
||||
return authIssuer + "/api/v2/users/" + userID
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile. Calls Auth0 API.
|
||||
func (am *Auth0Manager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var list []*UserData
|
||||
|
||||
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
|
||||
// auth0 limitation of 1000 users via this endpoint
|
||||
resultsPerPage := 50
|
||||
for page := 0; page < 20; page++ {
|
||||
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, strings.NewReader(query.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
|
||||
}
|
||||
|
||||
var batch []UserData
|
||||
err = json.Unmarshal(body, &batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("returned user batch for accountID %s on page %d, batch length %d", accountID, page, len(batch))
|
||||
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for user := range batch {
|
||||
list = append(list, &batch[user])
|
||||
}
|
||||
|
||||
if len(batch) == 0 || len(batch) < resultsPerPage {
|
||||
log.WithContext(ctx).Debugf("finished loading users for accountID %s", accountID)
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from auth0 via ID
|
||||
func (am *Auth0Manager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := requestByUserIDURL(am.authIssuer, userID)
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var userData UserData
|
||||
err = json.Unmarshal(body, &userData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get UserData, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return &userData, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userId and metadata map
|
||||
func (am *Auth0Manager) UpdateUserAppMetadata(ctx context.Context, userID string, appMetadata AppMetadata) error {
|
||||
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + "/api/v2/users/" + userID
|
||||
|
||||
data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
payload := strings.NewReader(string(data))
|
||||
|
||||
req, err := http.NewRequest("PATCH", reqURL, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debugf("updating IdP metadata for user %s", userID)
|
||||
|
||||
res, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", res.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCreateUserRequestPayload(email, name, accountID, invitedByEmail string) (string, error) {
|
||||
invite := true
|
||||
req := &createUserRequest{
|
||||
Email: email,
|
||||
Name: name,
|
||||
AppMeta: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &invite,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
Connection: "Username-Password-Authentication",
|
||||
Password: GeneratePassword(8, 1, 1, 1),
|
||||
VerifyEmail: true,
|
||||
}
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func buildUserExportRequest() (string, error) {
|
||||
req := &userExportJobRequest{}
|
||||
fields := make([]map[string]string, 0)
|
||||
|
||||
for _, field := range []string{"created_at", "last_login", "user_id", "email", "name"} {
|
||||
fields = append(fields, map[string]string{"name": field})
|
||||
}
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_account_id",
|
||||
"export_as": "wt_account_id",
|
||||
})
|
||||
|
||||
fields = append(fields, map[string]string{
|
||||
"name": "app_metadata.wt_pending_invite",
|
||||
"export_as": "wt_pending_invite",
|
||||
})
|
||||
|
||||
req.Format = "json"
|
||||
req.Fields = fields
|
||||
|
||||
str, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createRequest(
|
||||
ctx context.Context, method string, endpoint string, body io.Reader,
|
||||
) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := am.authIssuer + endpoint
|
||||
|
||||
req, err := http.NewRequest(method, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest(ctx, "POST", endpoint, strings.NewReader(payloadStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *Auth0Manager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
payloadString, err := buildUserExportRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exportJobReq, err := am.createPostRequest(ctx, "/api/v2/jobs/users-exports", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jobResp, err := am.httpClient.Do(exportJobReq)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = jobResp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing update user app metadata response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if jobResp.StatusCode != 200 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to update the appMetadata, statusCode %d", jobResp.StatusCode)
|
||||
}
|
||||
|
||||
var exportJobResp userExportJobResponse
|
||||
|
||||
body, err := io.ReadAll(jobResp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &exportJobResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if exportJobResp.ID == "" {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("couldn't get an batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("batch id status %d, %s, response body: %v", jobResp.StatusCode, jobResp.Status, exportJobResp)
|
||||
|
||||
done, downloadLink, err := am.checkExportJobStatus(ctx, exportJobResp.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed at getting status checks from exportJob; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if done {
|
||||
return am.downloadProfileExport(ctx, downloadLink)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed extracting user profiles from auth0")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
|
||||
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
|
||||
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
|
||||
func (am *Auth0Manager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + url.QueryEscape(email)
|
||||
body, err := doGetReq(ctx, am.httpClient, reqURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
userResp := []*UserData{}
|
||||
|
||||
err = am.helper.Unmarshal(body, &userResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userResp, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Auth0 Idp and sends an invite
|
||||
func (am *Auth0Manager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
|
||||
payloadString, err := buildCreateUserRequestPayload(email, name, accountID, invitedByEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/users", payloadString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing create user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var createResp UserData
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &createResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal export job response; %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if createResp.ID == "" {
|
||||
return nil, fmt.Errorf("couldn't create user: response %v", resp)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("created user %s in account %s", createResp.ID, accountID)
|
||||
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *Auth0Manager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
userVerificationReq := userVerificationJobRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
payload, err := am.helper.Marshal(userVerificationReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := am.createPostRequest(ctx, "/api/v2/jobs/verification-email", string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing invite user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser from Auth0
|
||||
func (am *Auth0Manager) DeleteUser(ctx context.Context, userID string) error {
|
||||
req, err := am.createRequest(ctx, http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("execute delete request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("close delete request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 204 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllConnections returns detailed list of all connections filtered by given params.
|
||||
// Note this method is not part of the IDP Manager interface as this is Auth0 specific.
|
||||
func (am *Auth0Manager) GetAllConnections(ctx context.Context, strategy []string) ([]Connection, error) {
|
||||
var connections []Connection
|
||||
|
||||
q := make(url.Values)
|
||||
q.Set("strategy", strings.Join(strategy, ","))
|
||||
|
||||
req, err := am.createRequest(ctx, http.MethodGet, "/api/v2/connections?"+q.Encode(), nil)
|
||||
if err != nil {
|
||||
return connections, err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("execute get connections request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return connections, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("close get connections request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 200 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return connections, fmt.Errorf("unable to get connections, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't read get connections response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
err = am.helper.Unmarshal(body, &connections)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("Couldn't unmarshal get connection response; %v", err)
|
||||
return connections, err
|
||||
}
|
||||
|
||||
return connections, err
|
||||
}
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(ctx context.Context, jobID string) (bool, string, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 90*time.Second)
|
||||
defer cancel()
|
||||
retry := time.NewTicker(10 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Debugf("Export job status stopped...\n")
|
||||
return false, "", ctx.Err()
|
||||
case <-retry.C:
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
statusURL := am.authIssuer + "/api/v2/jobs/" + jobID
|
||||
body, err := doGetReq(ctx, am.httpClient, statusURL, jwtToken.AccessToken)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
var status userExportJobStatusResponse
|
||||
err = am.helper.Unmarshal(body, &status)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("current export job status is %v", status.Status)
|
||||
|
||||
if status.Status != "completed" {
|
||||
continue
|
||||
}
|
||||
|
||||
return true, status.Location, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// downloadProfileExport downloads user profiles from auth0 batch job
|
||||
func (am *Auth0Manager) downloadProfileExport(ctx context.Context, location string) (map[string][]*UserData, error) {
|
||||
body, err := doGetReq(ctx, am.httpClient, location, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader := bytes.NewReader(body)
|
||||
|
||||
gzipReader, err := gzip.NewReader(bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(gzipReader)
|
||||
|
||||
res := make(map[string][]*UserData)
|
||||
for decoder.More() {
|
||||
profile := auth0Profile{}
|
||||
err = decoder.Decode(&profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.AccountID != "" {
|
||||
if _, ok := res[profile.AccountID]; !ok {
|
||||
res[profile.AccountID] = []*UserData{}
|
||||
}
|
||||
res[profile.AccountID] = append(res[profile.AccountID],
|
||||
&UserData{
|
||||
ID: profile.UserID,
|
||||
Name: profile.Name,
|
||||
Email: profile.Email,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: profile.AccountID,
|
||||
WTPendingInvite: &profile.PendingInvite,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Boilerplate implementation for Get Requests.
|
||||
func doGetReq(ctx context.Context, client ManagerHTTPClient, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if accessToken != "" {
|
||||
req.Header.Add("authorization", "Bearer "+accessToken)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = res.Body.Close()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error while closing body for url %s: %v", url, err)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
@@ -1,473 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type mockHTTPClient struct {
|
||||
code int
|
||||
resBody string
|
||||
reqBody string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: io.NopCloser(strings.NewReader(c.resBody)),
|
||||
}, c.err
|
||||
}
|
||||
|
||||
type mockJsonParser struct {
|
||||
jsonParser JsonParser
|
||||
marshalErrorString string
|
||||
unmarshalErrorString string
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
|
||||
if m.marshalErrorString != "" {
|
||||
return nil, errors.New(m.marshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Marshal(v)
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
|
||||
if m.unmarshalErrorString != "" {
|
||||
return errors.New(m.unmarshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
type mockAuth0Credentials struct {
|
||||
jwtToken JWTToken
|
||||
err error
|
||||
}
|
||||
|
||||
func (mc *mockAuth0Credentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return mc.jwtToken, mc.err
|
||||
}
|
||||
|
||||
func newTestJWT(t *testing.T, expInt int) string {
|
||||
t.Helper()
|
||||
now := time.Now()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"iat": now.Unix(),
|
||||
"exp": now.Add(time.Duration(expInt) * time.Second).Unix(),
|
||||
})
|
||||
var hmacSampleSecret []byte
|
||||
tokenString, err := token.SignedString(hmacSampleSecret)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func TestAuth0_RequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
res, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = json.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputResBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputResBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_JwtStillValid(t *testing.T) {
|
||||
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_Authenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
// expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
creds := Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
|
||||
|
||||
type updateUserAppMetadataTest struct {
|
||||
name string
|
||||
inputReqBody string
|
||||
expectedReqBody string
|
||||
appMetadata AppMetadata
|
||||
statusCode int
|
||||
helper ManagerHelper
|
||||
managerCreds ManagerCredentials
|
||||
assertErrFunc func(t assert.TestingT, err error, msgAndArgs ...interface{}) bool
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 15
|
||||
token := newTestJWT(t, exp)
|
||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
||||
|
||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
||||
name: "Bad Authentication",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: "",
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAuth0Credentials{
|
||||
jwtToken: JWTToken{},
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
||||
name: "Bad Status Code",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 400,
|
||||
helper: JsonParser{},
|
||||
managerCreds: &mockAuth0Credentials{
|
||||
jwtToken: JWTToken{},
|
||||
},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
||||
name: "Bad Response Parsing",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
statusCode: 400,
|
||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "should return error",
|
||||
}
|
||||
|
||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
||||
name: "Good request",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountID),
|
||||
appMetadata: appMetadata,
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
invite := true
|
||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
||||
name: "Update Pending Invite",
|
||||
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":true}}", appMetadata.WTAccountID),
|
||||
appMetadata: AppMetadata{
|
||||
WTAccountID: "ok",
|
||||
WTPendingInvite: &invite,
|
||||
},
|
||||
statusCode: 200,
|
||||
helper: JsonParser{},
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputReqBody,
|
||||
code: testCase.statusCode,
|
||||
}
|
||||
config := Auth0ClientConfig{}
|
||||
|
||||
var creds ManagerCredentials
|
||||
if testCase.managerCreds != nil {
|
||||
creds = testCase.managerCreds
|
||||
} else {
|
||||
creds = &Auth0Credentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
}
|
||||
|
||||
manager := Auth0Manager{
|
||||
httpClient: &jwtReqClient,
|
||||
credentials: creds,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
err := manager.UpdateUserAppMetadata(context.Background(), "1", testCase.appMetadata)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equal(t, testCase.expectedReqBody, jwtReqClient.reqBody, "request body should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuth0Manager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig Auth0ClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := Auth0ClientConfig{
|
||||
AuthIssuer: "https://abc-auth0.eu.auth0.com",
|
||||
Audience: "https://abc-auth0.eu.auth0.com/api/v2/",
|
||||
ClientID: "abcdefg",
|
||||
ClientSecret: "supersecret",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Scenario With Config",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "shouldn't return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.AuthIssuer = "abc-auth0.eu.auth0.com"
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuth0Manager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/api/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// AuthentikManager authentik manager client instance.
|
||||
type AuthentikManager struct {
|
||||
apiClient *api.APIClient
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// AuthentikClientConfig authentik manager client configurations.
|
||||
type AuthentikClientConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Username string
|
||||
Password string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// AuthentikCredentials authentik authentication information.
|
||||
type AuthentikCredentials struct {
|
||||
clientConfig AuthentikClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewAuthentikManager creates a new instance of the AuthentikManager.
|
||||
func NewAuthentikManager(config AuthentikClientConfig,
|
||||
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.Username == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Username is missing")
|
||||
}
|
||||
|
||||
if config.Password == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Password is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.Issuer == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
// authentik client configuration
|
||||
issuerURL, err := url.Parse(config.Issuer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authentikConfig := api.NewConfiguration()
|
||||
authentikConfig.HTTPClient = httpClient
|
||||
authentikConfig.Host = issuerURL.Host
|
||||
authentikConfig.Scheme = issuerURL.Scheme
|
||||
|
||||
credentials := &AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &AuthentikManager{
|
||||
apiClient: api.NewAPIClient(authentikConfig),
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from authentik.
|
||||
func (ac *AuthentikCredentials) jwtStillValid() bool {
|
||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AuthentikCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("username", ac.clientConfig.Username)
|
||||
data.Set("password", ac.clientConfig.Password)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
data.Set("scope", "goauthentik.io/api")
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for authentik idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get authentik token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (ac *AuthentikCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the authentik management API.
|
||||
func (ac *AuthentikCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if ac.jwtStillValid() {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
|
||||
ac.jwtToken = jwtToken
|
||||
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (am *AuthentikManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from authentik via ID.
|
||||
func (am *AuthentikManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, resp, err := am.apiClient.CoreApi.CoreUsersRetrieve(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData := parseAuthentikUser(*user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AuthentikManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AuthentikManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Authentik account.
|
||||
func (am *AuthentikManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
page := int32(1)
|
||||
for {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Page(page).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for _, user := range userList.Results {
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
page = int32(userList.GetPagination().Next)
|
||||
if userList.GetPagination().Next == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in authentik Idp and sends an invitation.
|
||||
func (am *AuthentikManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AuthentikManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userList, resp, err := am.apiClient.CoreApi.CoreUsersList(ctx).Email(email).Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
users = append(users, parseAuthentikUser(user))
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AuthentikManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Authentik
|
||||
func (am *AuthentikManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
ctx, err := am.authenticationContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext(ctx context.Context) (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value := map[string]api.APIKey{
|
||||
"authentik": {
|
||||
Key: jwtToken.AccessToken,
|
||||
Prefix: jwtToken.TokenType,
|
||||
},
|
||||
}
|
||||
return context.WithValue(context.Background(), api.ContextAPIKeys, value), nil
|
||||
}
|
||||
|
||||
func parseAuthentikUser(user api.User) *UserData {
|
||||
return &UserData{
|
||||
Email: *user.Email,
|
||||
Name: user.Name,
|
||||
ID: strconv.FormatInt(int64(user.Pk), 10),
|
||||
}
|
||||
}
|
||||
@@ -1,320 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewAuthentikManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig AuthentikClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := AuthentikClientConfig{
|
||||
ClientID: "client_id",
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
TokenEndpoint: "https://localhost:8080/application/o/token/",
|
||||
Issuer: "https://localhost:8080/application/o/netbird/",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.Username = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing Username Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.Password = ""
|
||||
|
||||
testCase4 := test{
|
||||
name: "Missing Password Configuration",
|
||||
inputConfig: testCase4Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Missing GrantType Configuration",
|
||||
inputConfig: testCase5Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase6Config := defaultTestConfig
|
||||
testCase6Config.Issuer = ""
|
||||
|
||||
testCase6 := test{
|
||||
name: "Missing Issuer Configuration",
|
||||
inputConfig: testCase6Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikRequestJWTToken(t *testing.T) {
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthentikAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get authentik token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AuthentikClientConfig{}
|
||||
|
||||
creds := AuthentikCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,459 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const profileFields = "id,displayName,mail,userPrincipalName"
|
||||
|
||||
// AzureManager azure manager client instance.
|
||||
type AzureManager struct {
|
||||
ClientID string
|
||||
ObjectID string
|
||||
GraphAPIEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// AzureClientConfig azure manager client configurations.
|
||||
type AzureClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ObjectID string
|
||||
GraphAPIEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// AzureCredentials azure authentication information.
|
||||
type AzureCredentials struct {
|
||||
clientConfig AzureClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// azureProfile represents an azure user profile.
|
||||
type azureProfile map[string]any
|
||||
|
||||
// NewAzureManager creates a new instance of the AzureManager.
|
||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GraphAPIEndpoint == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GraphAPIEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.ObjectID == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, ObjectID is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("azure IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &AzureCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &AzureManager{
|
||||
ObjectID: config.ObjectID,
|
||||
ClientID: config.ClientID,
|
||||
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
||||
func (ac *AzureCredentials) jwtStillValid() bool {
|
||||
return !ac.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(ac.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (ac *AzureCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", ac.clientConfig.ClientID)
|
||||
data.Set("client_secret", ac.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", ac.clientConfig.GrantType)
|
||||
parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// get base url and add "/.default" as scope
|
||||
baseURL := parsedURL.Scheme + "://" + parsedURL.Host
|
||||
scopeURL := baseURL + "/.default"
|
||||
data.Set("scope", scopeURL)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for azure idp manager")
|
||||
|
||||
resp, err := ac.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get azure token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (ac *AzureCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = ac.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = ac.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the azure Management API.
|
||||
func (ac *AzureCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
ac.mux.Lock()
|
||||
defer ac.mux.Unlock()
|
||||
|
||||
if ac.appMetrics != nil {
|
||||
ac.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if ac.jwtStillValid() {
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := ac.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := ac.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return ac.jwtToken, err
|
||||
}
|
||||
|
||||
ac.jwtToken = jwtToken
|
||||
|
||||
return ac.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in azure AD Idp.
|
||||
func (am *AzureManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (am *AzureManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get(ctx, "users/"+userID, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (am *AzureManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
|
||||
body, err := am.get(ctx, "users/"+email, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profile azureProfile
|
||||
err = am.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, profile.userData())
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (am *AzureManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (am *AzureManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
users, err := am.getAllUsers(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||
func (am *AzureManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AzureManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure.
|
||||
func (am *AzureManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.WithContext(ctx).Debugf("delete idp user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Azure AD account.
|
||||
func (am *AzureManager) getAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("$select", profileFields)
|
||||
q.Add("$top", "500")
|
||||
|
||||
for nextLink := "users"; nextLink != ""; {
|
||||
body, err := am.get(ctx, nextLink, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles struct {
|
||||
Value []azureProfile
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
err = am.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, profile := range profiles.Value {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
nextLink = profiles.NextLink
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (am *AzureManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := am.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reqURL string
|
||||
if strings.HasPrefix(resource, "https") {
|
||||
// Already an absolute URL for paging
|
||||
reqURL = resource
|
||||
} else {
|
||||
reqURL = fmt.Sprintf("%s/%s?%s", am.GraphAPIEndpoint, resource, q.Encode())
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (ap azureProfile) userData() *UserData {
|
||||
id, ok := ap["id"].(string)
|
||||
if !ok {
|
||||
id = ""
|
||||
}
|
||||
|
||||
email, ok := ap["userPrincipalName"].(string)
|
||||
if !ok {
|
||||
email = ""
|
||||
}
|
||||
|
||||
name, ok := ap["displayName"].(string)
|
||||
if !ok {
|
||||
name = ""
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: id,
|
||||
}
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAzureJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := AzureClientConfig{}
|
||||
|
||||
creds := AzureCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get azure token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := AzureClientConfig{}
|
||||
|
||||
creds := AzureCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile azureProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "Good Request",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test1",
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test1@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test1@test.com",
|
||||
Name: "John Doe",
|
||||
ID: "test1",
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Missing User ID",
|
||||
invite: true,
|
||||
inputProfile: azureProfile{
|
||||
"displayName": "John Doe",
|
||||
"userPrincipalName": "test2@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
Email: "test2@test.com",
|
||||
Name: "John Doe",
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase3 := azureProfileTest{
|
||||
name: "Missing User Name",
|
||||
invite: false,
|
||||
inputProfile: azureProfile{
|
||||
"id": "test3",
|
||||
"userPrincipalName": "test3@test.com",
|
||||
},
|
||||
expectedUserData: UserData{
|
||||
ID: "test3",
|
||||
Email: "test3@test.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// GoogleWorkspaceManager Google Workspace manager client instance.
|
||||
type GoogleWorkspaceManager struct {
|
||||
usersService *admin.UsersService
|
||||
CustomerID string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// GoogleWorkspaceClientConfig Google Workspace manager client configurations.
|
||||
type GoogleWorkspaceClientConfig struct {
|
||||
ServiceAccountKey string
|
||||
CustomerID string
|
||||
}
|
||||
|
||||
// GoogleWorkspaceCredentials Google Workspace authentication information.
|
||||
type GoogleWorkspaceCredentials struct {
|
||||
clientConfig GoogleWorkspaceClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
func (gc *GoogleWorkspaceCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// NewGoogleWorkspaceManager creates a new instance of the GoogleWorkspaceManager.
|
||||
func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClientConfig, appMetrics telemetry.AppMetrics) (*GoogleWorkspaceManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.CustomerID == "" {
|
||||
return nil, fmt.Errorf("google IdP configuration is incomplete, CustomerID is missing")
|
||||
}
|
||||
|
||||
credentials := &GoogleWorkspaceCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
// Create a new Admin SDK Directory service client
|
||||
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := admin.NewService(context.Background(),
|
||||
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
|
||||
option.WithCredentials(adminCredentials),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GoogleWorkspaceManager{
|
||||
usersService: service.Users,
|
||||
CustomerID: config.CustomerID,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (gm *GoogleWorkspaceManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from Google Workspace via ID.
|
||||
func (gm *GoogleWorkspaceManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, err := gm.usersService.Get(userID).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
userData := parseGoogleWorkspaceUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (gm *GoogleWorkspaceManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (gm *GoogleWorkspaceManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := gm.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in a Google Workspace account filtered by customer ID.
|
||||
func (gm *GoogleWorkspaceManager) getAllUsers() ([]*UserData, error) {
|
||||
users := make([]*UserData, 0)
|
||||
pageToken := ""
|
||||
for {
|
||||
call := gm.usersService.List().Customer(gm.CustomerID).MaxResults(500)
|
||||
if pageToken != "" {
|
||||
call.PageToken(pageToken)
|
||||
}
|
||||
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range resp.Users {
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
}
|
||||
|
||||
pageToken = resp.NextPageToken
|
||||
if pageToken == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in Google Workspace and sends an invitation.
|
||||
func (gm *GoogleWorkspaceManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (gm *GoogleWorkspaceManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, err := gm.usersService.Get(email).Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, parseGoogleWorkspaceUser(user))
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from GoogleWorkspace.
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) error {
|
||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
// It returns the retrieved credentials or an error if unsuccessful.
|
||||
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
|
||||
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
|
||||
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode service account key: %w", err)
|
||||
}
|
||||
|
||||
creds, err := google.CredentialsFromJSON(
|
||||
context.Background(),
|
||||
decodeKey,
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err == nil {
|
||||
// No need to fallback to the default Google credentials path
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
|
||||
log.WithContext(ctx).Debug("falling back to default google credentials location")
|
||||
|
||||
creds, err = google.FindDefaultCredentials(
|
||||
context.Background(),
|
||||
admin.AdminDirectoryUserReadonlyScope,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// parseGoogleWorkspaceUser parse google user to UserData.
|
||||
func parseGoogleWorkspaceUser(user *admin.User) *UserData {
|
||||
return &UserData{
|
||||
ID: user.Id,
|
||||
Email: user.PrimaryEmail,
|
||||
Name: user.Name.FullName,
|
||||
}
|
||||
}
|
||||
@@ -11,24 +11,25 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
// UnsetAccountID is a special key to map users without an account ID
|
||||
UnsetAccountID = "unset"
|
||||
)
|
||||
|
||||
// Manager idp manager interface
|
||||
// Note: NetBird is the single source of truth for authorization data (roles, account membership, invite status).
|
||||
// The IdP only stores identity information (email, name, credentials).
|
||||
type Manager interface {
|
||||
UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccount(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccounts(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
// CreateUser creates a new user in the IdP. Returns basic user data (ID, email, name).
|
||||
CreateUser(ctx context.Context, email, name string) (*UserData, error)
|
||||
// GetUserDataByID retrieves user identity data from the IdP by user ID.
|
||||
GetUserDataByID(ctx context.Context, userId string) (*UserData, error)
|
||||
// GetUserByEmail searches for users by email address.
|
||||
GetUserByEmail(ctx context.Context, email string) ([]*UserData, error)
|
||||
// GetAllUsers returns all users from the IdP for cache warming.
|
||||
GetAllUsers(ctx context.Context) ([]*UserData, error)
|
||||
// InviteUserByID resends an invitation to a user who hasn't completed signup.
|
||||
InviteUserByID(ctx context.Context, userID string) error
|
||||
// DeleteUser removes a user from the IdP.
|
||||
DeleteUser(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
// ClientConfig defines common client configuration for the IdP manager
|
||||
type ClientConfig struct {
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
@@ -42,13 +43,10 @@ type ExtraConfig map[string]string
|
||||
|
||||
// Config an idp configuration struct to be loaded from management server's config file
|
||||
type Config struct {
|
||||
ManagerType string
|
||||
ClientConfig *ClientConfig
|
||||
ExtraConfig ExtraConfig
|
||||
Auth0ClientCredentials *Auth0ClientConfig
|
||||
AzureClientCredentials *AzureClientConfig
|
||||
KeycloakClientCredentials *KeycloakClientConfig
|
||||
ZitadelClientCredentials *ZitadelClientConfig
|
||||
ManagerType string
|
||||
ClientConfig *ClientConfig
|
||||
ExtraConfig ExtraConfig
|
||||
ZitadelClientCredentials *ZitadelClientConfig
|
||||
}
|
||||
|
||||
// ManagerCredentials interface that authenticates using the credential of each type of idp
|
||||
@@ -67,11 +65,12 @@ type ManagerHelper interface {
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
}
|
||||
|
||||
// UserData represents identity information from the IdP.
|
||||
// Note: Authorization data (account membership, roles, invite status) is stored in NetBird's DB.
|
||||
type UserData struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
AppMetadata AppMetadata `json:"app_metadata"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"user_id"`
|
||||
}
|
||||
|
||||
func (u *UserData) MarshalBinary() (data []byte, err error) {
|
||||
@@ -91,15 +90,6 @@ func (u *UserData) Unmarshal(data []byte) (err error) {
|
||||
return json.Unmarshal(data, &u)
|
||||
}
|
||||
|
||||
// AppMetadata user app metadata to associate with a profile
|
||||
type AppMetadata struct {
|
||||
// WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
|
||||
// maps to wt_account_id when json.marshal
|
||||
WTAccountID string `json:"wt_account_id,omitempty"`
|
||||
WTPendingInvite *bool `json:"wt_pending_invite,omitempty"`
|
||||
WTInvitedBy string `json:"wt_invited_by_email,omitempty"`
|
||||
}
|
||||
|
||||
// JWTToken a JWT object that holds information of a token
|
||||
type JWTToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -109,7 +99,8 @@ type JWTToken struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// NewManager returns a new idp manager based on the configuration that it receives
|
||||
// NewManager returns a new idp manager based on the configuration that it receives.
|
||||
// Only Zitadel is supported as the IdP manager.
|
||||
func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetrics) (Manager, error) {
|
||||
if config.ClientConfig != nil {
|
||||
config.ClientConfig.Issuer = strings.TrimSuffix(config.ClientConfig.Issuer, "/")
|
||||
@@ -118,46 +109,6 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
switch strings.ToLower(config.ManagerType) {
|
||||
case "none", "":
|
||||
return nil, nil //nolint:nilnil
|
||||
case "auth0":
|
||||
auth0ClientConfig := config.Auth0ClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
auth0ClientConfig = &Auth0ClientConfig{
|
||||
Audience: config.ExtraConfig["Audience"],
|
||||
AuthIssuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
}
|
||||
}
|
||||
|
||||
return NewAuth0Manager(*auth0ClientConfig, appMetrics)
|
||||
case "azure":
|
||||
azureClientConfig := config.AzureClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
azureClientConfig = &AzureClientConfig{
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
ObjectID: config.ExtraConfig["ObjectId"],
|
||||
GraphAPIEndpoint: config.ExtraConfig["GraphApiEndpoint"],
|
||||
}
|
||||
}
|
||||
|
||||
return NewAzureManager(*azureClientConfig, appMetrics)
|
||||
case "keycloak":
|
||||
keycloakClientConfig := config.KeycloakClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
keycloakClientConfig = &KeycloakClientConfig{
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
ClientSecret: config.ClientConfig.ClientSecret,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
AdminEndpoint: config.ExtraConfig["AdminEndpoint"],
|
||||
}
|
||||
}
|
||||
|
||||
return NewKeycloakManager(*keycloakClientConfig, appMetrics)
|
||||
case "zitadel":
|
||||
zitadelClientConfig := config.ZitadelClientCredentials
|
||||
if config.ClientConfig != nil {
|
||||
@@ -172,42 +123,7 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
}
|
||||
|
||||
return NewZitadelManager(*zitadelClientConfig, appMetrics)
|
||||
case "authentik":
|
||||
authentikConfig := AuthentikClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
ClientID: config.ClientConfig.ClientID,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
Username: config.ExtraConfig["Username"],
|
||||
Password: config.ExtraConfig["Password"],
|
||||
}
|
||||
return NewAuthentikManager(authentikConfig, appMetrics)
|
||||
case "okta":
|
||||
oktaClientConfig := OktaClientConfig{
|
||||
Issuer: config.ClientConfig.Issuer,
|
||||
TokenEndpoint: config.ClientConfig.TokenEndpoint,
|
||||
GrantType: config.ClientConfig.GrantType,
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewOktaManager(oktaClientConfig, appMetrics)
|
||||
case "google":
|
||||
googleClientConfig := GoogleWorkspaceClientConfig{
|
||||
ServiceAccountKey: config.ExtraConfig["ServiceAccountKey"],
|
||||
CustomerID: config.ExtraConfig["CustomerId"],
|
||||
}
|
||||
return NewGoogleWorkspaceManager(ctx, googleClientConfig, appMetrics)
|
||||
case "jumpcloud":
|
||||
jumpcloudConfig := JumpCloudClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
||||
case "pocketid":
|
||||
pocketidConfig := PocketIdClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||
}
|
||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
return nil, fmt.Errorf("unsupported IdP manager type: %s (only 'zitadel' is supported)", config.ManagerType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
v1 "github.com/TheJumpCloud/jcapi-go/v1"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
accept = "application/json"
|
||||
)
|
||||
|
||||
// JumpCloudManager JumpCloud manager client instance.
|
||||
type JumpCloudManager struct {
|
||||
client *v1.APIClient
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// JumpCloudClientConfig JumpCloud manager client configurations.
|
||||
type JumpCloudClientConfig struct {
|
||||
APIToken string
|
||||
}
|
||||
|
||||
// JumpCloudCredentials JumpCloud authentication information.
|
||||
type JumpCloudCredentials struct {
|
||||
clientConfig JumpCloudClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewJumpCloudManager creates a new instance of the JumpCloudManager.
|
||||
func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppMetrics) (*JumpCloudManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("jumpCloud IdP configuration is incomplete, ApiToken is missing")
|
||||
}
|
||||
|
||||
client := v1.NewAPIClient(v1.NewConfiguration())
|
||||
credentials := &JumpCloudCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &JumpCloudManager{
|
||||
client: client,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the JumpCloud user API.
|
||||
func (jc *JumpCloudCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
func (jm *JumpCloudManager) authenticationContext() context.Context {
|
||||
return context.WithValue(context.Background(), v1.ContextAPIKey, v1.APIKey{
|
||||
Key: jm.apiToken,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (jm *JumpCloudManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from JumpCloud via ID.
|
||||
func (jm *JumpCloudManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
user, resp, err := jm.client.SystemusersApi.SystemusersGet(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (jm *JumpCloudManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get account %s users, statusCode %d", accountID, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
userData := parseJumpCloudUser(user)
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (jm *JumpCloudManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, user := range userList.Results {
|
||||
userData := parseJumpCloudUser(user)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in JumpCloud Idp and sends an invitation.
|
||||
func (jm *JumpCloudManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (jm *JumpCloudManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
searchFilter := map[string]interface{}{
|
||||
"searchFilter": map[string]interface{}{
|
||||
"filter": []string{email},
|
||||
"fields": []string{"email"},
|
||||
},
|
||||
}
|
||||
|
||||
authCtx := jm.authenticationContext()
|
||||
userList, resp, err := jm.client.SearchApi.SearchSystemusersPost(authCtx, contentType, accept, searchFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
usersData := make([]*UserData, 0)
|
||||
for _, user := range userList.Results {
|
||||
usersData = append(usersData, parseJumpCloudUser(user))
|
||||
}
|
||||
|
||||
return usersData, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (jm *JumpCloudManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from jumpCloud directory
|
||||
func (jm *JumpCloudManager) DeleteUser(_ context.Context, userID string) error {
|
||||
authCtx := jm.authenticationContext()
|
||||
_, resp, err := jm.client.SystemusersApi.SystemusersDelete(authCtx, userID, contentType, accept, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if jm.appMetrics != nil {
|
||||
jm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseJumpCloudUser parse JumpCloud system user returned from API V1 to UserData.
|
||||
func parseJumpCloudUser(user v1.Systemuserreturn) *UserData {
|
||||
names := []string{user.Firstname, user.Middlename, user.Lastname}
|
||||
return &UserData{
|
||||
Email: user.Email,
|
||||
Name: strings.Join(names, " "),
|
||||
ID: user.Id,
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewJumpCloudManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig JumpCloudClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := JumpCloudClientConfig{
|
||||
APIToken: "test123",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.APIToken = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing APIToken Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewJumpCloudManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,439 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// KeycloakManager keycloak manager client instance.
|
||||
type KeycloakManager struct {
|
||||
adminEndpoint string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// KeycloakClientConfig keycloak manager client configurations.
|
||||
type KeycloakClientConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AdminEndpoint string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// KeycloakCredentials keycloak authentication information.
|
||||
type KeycloakCredentials struct {
|
||||
clientConfig KeycloakClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
jwtToken JWTToken
|
||||
mux sync.Mutex
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// keycloakUserAttributes holds additional user data fields.
|
||||
type keycloakUserAttributes map[string][]string
|
||||
|
||||
// keycloakProfile represents a keycloak user profile response.
|
||||
type keycloakProfile struct {
|
||||
ID string `json:"id"`
|
||||
CreatedTimestamp int64 `json:"createdTimestamp"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Attributes keycloakUserAttributes `json:"attributes"`
|
||||
}
|
||||
|
||||
// NewKeycloakManager creates a new instance of the KeycloakManager.
|
||||
func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ClientID == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, clientID is missing")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, ClientSecret is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.AdminEndpoint == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, AdminEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("keycloak IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
credentials := &KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &KeycloakManager{
|
||||
adminEndpoint: config.AdminEndpoint,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak.
|
||||
func (kc *KeycloakCredentials) jwtStillValid() bool {
|
||||
return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime)
|
||||
}
|
||||
|
||||
// requestJWTToken performs request to get jwt token.
|
||||
func (kc *KeycloakCredentials) requestJWTToken(ctx context.Context) (*http.Response, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kc.clientConfig.ClientID)
|
||||
data.Set("client_secret", kc.clientConfig.ClientSecret)
|
||||
data.Set("grant_type", kc.clientConfig.GrantType)
|
||||
|
||||
payload := strings.NewReader(data.Encode())
|
||||
req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/x-www-form-urlencoded")
|
||||
|
||||
log.WithContext(ctx).Debug("requesting new jwt token for keycloak idp manager")
|
||||
|
||||
resp, err := kc.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
|
||||
func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
|
||||
jwtToken := JWTToken{}
|
||||
body, err := io.ReadAll(rawBody)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
err = kc.helper.Unmarshal(body, &jwtToken)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" {
|
||||
return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken)
|
||||
}
|
||||
|
||||
data, err := base64.RawURLEncoding.DecodeString(strings.Split(jwtToken.AccessToken, ".")[1])
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
|
||||
// Exp maps into exp from jwt token
|
||||
var IssuedAt struct{ Exp int64 }
|
||||
err = kc.helper.Unmarshal(data, &IssuedAt)
|
||||
if err != nil {
|
||||
return jwtToken, err
|
||||
}
|
||||
jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0)
|
||||
|
||||
return jwtToken, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the keycloak Management API.
|
||||
func (kc *KeycloakCredentials) Authenticate(ctx context.Context) (JWTToken, error) {
|
||||
kc.mux.Lock()
|
||||
defer kc.mux.Unlock()
|
||||
|
||||
if kc.appMetrics != nil {
|
||||
kc.appMetrics.IDPMetrics().CountAuthenticate()
|
||||
}
|
||||
|
||||
// reuse the token without requesting a new one if it is not expired,
|
||||
// and if expiry time is sufficient time available to make a request.
|
||||
if kc.jwtStillValid() {
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
resp, err := kc.requestJWTToken(ctx)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
jwtToken, err := kc.parseRequestJWTResponse(resp.Body)
|
||||
if err != nil {
|
||||
return kc.jwtToken, err
|
||||
}
|
||||
|
||||
kc.jwtToken = jwtToken
|
||||
|
||||
return kc.jwtToken, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in keycloak Idp and sends an invite.
|
||||
func (km *KeycloakManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (km *KeycloakManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
q := url.Values{}
|
||||
q.Add("email", email)
|
||||
q.Add("exact", "true")
|
||||
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (km *KeycloakManager) GetUserDataByID(ctx context.Context, userID string, _ AppMetadata) (*UserData, error) {
|
||||
body, err := km.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var profile keycloakProfile
|
||||
err = km.helper.Unmarshal(body, &profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profile.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given account profile.
|
||||
func (km *KeycloakManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (km *KeycloakManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
profiles, err := km.fetchAllUserProfiles(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range profiles {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (km *KeycloakManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (km *KeycloakManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloak by user ID.
|
||||
func (km *KeycloakManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
// In the docs, they specified 200, but in the endpoints, they return 204
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (km *KeycloakManager) fetchAllUserProfiles(ctx context.Context) ([]keycloakProfile, error) {
|
||||
totalUsers, err := km.totalUsersCount(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := url.Values{}
|
||||
q.Add("max", fmt.Sprint(*totalUsers))
|
||||
|
||||
body, err := km.get(ctx, "users", q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
profiles := make([]keycloakProfile, 0)
|
||||
err = km.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (km *KeycloakManager) get(ctx context.Context, resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := km.credentials.Authenticate(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode())
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// totalUsersCount returns the total count of all user created.
|
||||
// Used when fetching all registered accounts with pagination.
|
||||
func (km *KeycloakManager) totalUsersCount(ctx context.Context) (*int, error) {
|
||||
body, err := km.get(ctx, "users/count", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(string(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &count, nil
|
||||
}
|
||||
|
||||
// userData construct user data from keycloak profile.
|
||||
func (kp keycloakProfile) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: kp.Email,
|
||||
Name: kp.Username,
|
||||
ID: kp.ID,
|
||||
}
|
||||
}
|
||||
@@ -1,310 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewKeycloakManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig KeycloakClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := KeycloakClientConfig{
|
||||
ClientID: "client_id",
|
||||
ClientSecret: "client_secret",
|
||||
AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123",
|
||||
TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token",
|
||||
GrantType: "client_credentials",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
}
|
||||
|
||||
testCase2Config := defaultTestConfig
|
||||
testCase2Config.ClientID = ""
|
||||
|
||||
testCase2 := test{
|
||||
name: "Missing ClientID Configuration",
|
||||
inputConfig: testCase2Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase3Config := defaultTestConfig
|
||||
testCase3Config.ClientSecret = ""
|
||||
|
||||
testCase3 := test{
|
||||
name: "Missing ClientSecret Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase4Config := defaultTestConfig
|
||||
testCase4Config.TokenEndpoint = ""
|
||||
|
||||
testCase4 := test{
|
||||
name: "Missing TokenEndpoint Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
testCase5Config := defaultTestConfig
|
||||
testCase5Config.GrantType = ""
|
||||
|
||||
testCase5 := test{
|
||||
name: "Missing GrantType Configuration",
|
||||
inputConfig: testCase3Config,
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
_, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakRequestJWTToken(t *testing.T) {
|
||||
|
||||
type requestJWTTokenTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
requestJWTTokenTesttCase1 := requestJWTTokenTest{
|
||||
name: "Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
}
|
||||
requestJWTTokenTestCase2 := requestJWTTokenTest{
|
||||
name: "Request Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputRespBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputRespBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
|
||||
resp, err := creds.requestJWTToken(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err, "unable to read the response body")
|
||||
|
||||
jwtToken := JWTToken{}
|
||||
err = testCase.helper.Unmarshal(body, &jwtToken)
|
||||
assert.NoError(t, err, "unable to parse the json input")
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakParseRequestJWTResponse(t *testing.T) {
|
||||
type parseRequestJWTResponseTest struct {
|
||||
name string
|
||||
inputRespBody string
|
||||
helper ManagerHelper
|
||||
expectedToken string
|
||||
expectedExpiresIn int
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
exp := 100
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{
|
||||
name: "Parse Good JWT Body",
|
||||
inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedToken: token,
|
||||
expectedExpiresIn: exp,
|
||||
assertErrFunc: assert.NoError,
|
||||
assertErrFuncMessage: "no error was expected",
|
||||
}
|
||||
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
|
||||
name: "Parse Bad json JWT Body",
|
||||
inputRespBody: "",
|
||||
helper: JsonParser{},
|
||||
expectedToken: "",
|
||||
expectedExpiresIn: 0,
|
||||
assertErrFunc: assert.Error,
|
||||
assertErrFuncMessage: "json error was expected",
|
||||
}
|
||||
|
||||
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody))
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
jwtToken, err := creds.parseRequestJWTResponse(rawBody)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same")
|
||||
assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakJwtStillValid(t *testing.T) {
|
||||
type jwtStillValidTest struct {
|
||||
name string
|
||||
inputTime time.Time
|
||||
expectedResult bool
|
||||
message string
|
||||
}
|
||||
|
||||
jwtStillValidTestCase1 := jwtStillValidTest{
|
||||
name: "JWT still valid",
|
||||
inputTime: time.Now().Add(10 * time.Second),
|
||||
expectedResult: true,
|
||||
message: "should be true",
|
||||
}
|
||||
jwtStillValidTestCase2 := jwtStillValidTest{
|
||||
name: "JWT is invalid",
|
||||
inputTime: time.Now(),
|
||||
expectedResult: false,
|
||||
message: "should be false",
|
||||
}
|
||||
|
||||
for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputTime
|
||||
|
||||
assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeycloakAuthenticate(t *testing.T) {
|
||||
type authenticateTest struct {
|
||||
name string
|
||||
inputCode int
|
||||
inputResBody string
|
||||
inputExpireToken time.Time
|
||||
helper ManagerHelper
|
||||
expectedFuncExitErrDiff error
|
||||
expectedCode int
|
||||
expectedToken string
|
||||
}
|
||||
exp := 5
|
||||
token := newTestJWT(t, exp)
|
||||
|
||||
authenticateTestCase1 := authenticateTest{
|
||||
name: "Get Cached token",
|
||||
inputExpireToken: time.Now().Add(30 * time.Second),
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: nil,
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
authenticateTestCase2 := authenticateTest{
|
||||
name: "Get Good JWT Response",
|
||||
inputCode: 200,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
|
||||
helper: JsonParser{},
|
||||
expectedCode: 200,
|
||||
expectedToken: token,
|
||||
}
|
||||
|
||||
authenticateTestCase3 := authenticateTest{
|
||||
name: "Get Bad Status Code",
|
||||
inputCode: 400,
|
||||
inputResBody: "{}",
|
||||
helper: JsonParser{},
|
||||
expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"),
|
||||
expectedCode: 200,
|
||||
expectedToken: "",
|
||||
}
|
||||
|
||||
for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
jwtReqClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputCode,
|
||||
}
|
||||
config := KeycloakClientConfig{}
|
||||
|
||||
creds := KeycloakCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: &jwtReqClient,
|
||||
helper: testCase.helper,
|
||||
}
|
||||
creds.jwtToken.expiresInTime = testCase.inputExpireToken
|
||||
|
||||
_, err := creds.Authenticate(context.Background())
|
||||
if err != nil {
|
||||
if testCase.expectedFuncExitErrDiff != nil {
|
||||
assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same")
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,52 +4,26 @@ import "context"
|
||||
|
||||
// MockIDP is a mock implementation of the IDP interface
|
||||
type MockIDP struct {
|
||||
UpdateUserAppMetadataFunc func(ctx context.Context, userId string, appMetadata AppMetadata) error
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error)
|
||||
GetAccountFunc func(ctx context.Context, accountId string) ([]*UserData, error)
|
||||
GetAllAccountsFunc func(ctx context.Context) (map[string][]*UserData, error)
|
||||
CreateUserFunc func(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata is a mock implementation of the IDP interface UpdateUserAppMetadata method
|
||||
func (m *MockIDP) UpdateUserAppMetadata(ctx context.Context, userId string, appMetadata AppMetadata) error {
|
||||
if m.UpdateUserAppMetadataFunc != nil {
|
||||
return m.UpdateUserAppMetadataFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(ctx, userId, appMetadata)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAccount is a mock implementation of the IDP interface GetAccount method
|
||||
func (m *MockIDP) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
if m.GetAccountFunc != nil {
|
||||
return m.GetAccountFunc(ctx, accountId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts is a mock implementation of the IDP interface GetAllAccounts method
|
||||
func (m *MockIDP) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
if m.GetAllAccountsFunc != nil {
|
||||
return m.GetAllAccountsFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
CreateUserFunc func(ctx context.Context, email, name string) (*UserData, error)
|
||||
GetUserDataByIDFunc func(ctx context.Context, userId string) (*UserData, error)
|
||||
GetUserByEmailFunc func(ctx context.Context, email string) ([]*UserData, error)
|
||||
GetAllUsersFunc func(ctx context.Context) ([]*UserData, error)
|
||||
InviteUserByIDFunc func(ctx context.Context, userID string) error
|
||||
DeleteUserFunc func(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
// CreateUser is a mock implementation of the IDP interface CreateUser method
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
func (m *MockIDP) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||
if m.CreateUserFunc != nil {
|
||||
return m.CreateUserFunc(ctx, email, name, accountID, invitedByEmail)
|
||||
return m.CreateUserFunc(ctx, email, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetUserDataByID is a mock implementation of the IDP interface GetUserDataByID method
|
||||
func (m *MockIDP) GetUserDataByID(ctx context.Context, userId string) (*UserData, error) {
|
||||
if m.GetUserDataByIDFunc != nil {
|
||||
return m.GetUserDataByIDFunc(ctx, userId)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
@@ -62,6 +36,14 @@ func (m *MockIDP) GetUserByEmail(ctx context.Context, email string) ([]*UserData
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetAllUsers is a mock implementation of the IDP interface GetAllUsers method
|
||||
func (m *MockIDP) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
if m.GetAllUsersFunc != nil {
|
||||
return m.GetAllUsersFunc(ctx)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// InviteUserByID is a mock implementation of the IDP interface InviteUserByID method
|
||||
func (m *MockIDP) InviteUserByID(ctx context.Context, userID string) error {
|
||||
if m.InviteUserByIDFunc != nil {
|
||||
|
||||
@@ -1,306 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/okta/okta-sdk-golang/v2/okta/query"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
// OktaManager okta manager client instance.
|
||||
type OktaManager struct {
|
||||
client *okta.Client
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// OktaClientConfig okta manager client configurations.
|
||||
type OktaClientConfig struct {
|
||||
APIToken string
|
||||
Issuer string
|
||||
TokenEndpoint string
|
||||
GrantType string
|
||||
}
|
||||
|
||||
// OktaCredentials okta authentication information.
|
||||
type OktaCredentials struct {
|
||||
clientConfig OktaClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// NewOktaManager creates a new instance of the OktaManager.
|
||||
func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*OktaManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
|
||||
helper := JsonParser{}
|
||||
config.Issuer = baseURL(config.Issuer)
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
if config.Issuer == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, Issuer is missing")
|
||||
}
|
||||
|
||||
if config.TokenEndpoint == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, TokenEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.GrantType == "" {
|
||||
return nil, fmt.Errorf("okta IdP configuration is incomplete, GrantType is missing")
|
||||
}
|
||||
|
||||
_, client, err := okta.NewClient(context.Background(),
|
||||
okta.WithOrgUrl(config.Issuer),
|
||||
okta.WithToken(config.APIToken),
|
||||
okta.WithHttpClientPtr(httpClient),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
credentials := &OktaCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &OktaManager{
|
||||
client: client,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate retrieves access token to use the okta user API.
|
||||
func (oc *OktaCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in okta Idp and sends an invitation.
|
||||
func (om *OktaManager) CreateUser(_ context.Context, _, _, _, _ string) (*UserData, error) {
|
||||
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from keycloak via ID.
|
||||
func (om *OktaManager) GetUserDataByID(_ context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
// If no users have been found, this function returns an empty list.
|
||||
func (om *OktaManager) GetUserByEmail(_ context.Context, email string) ([]*UserData, error) {
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), url.QueryEscape(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get user %s, statusCode %d", email, resp.StatusCode)
|
||||
}
|
||||
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := make([]*UserData, 0)
|
||||
users = append(users, userData)
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (om *OktaManager) GetAccount(_ context.Context, accountID string) ([]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
for index, user := range users {
|
||||
user.AppMetadata.WTAccountID = accountID
|
||||
users[index] = user
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (om *OktaManager) GetAllAccounts(_ context.Context) (map[string][]*UserData, error) {
|
||||
users, err := om.getAllUsers()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], users...)
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// getAllUsers returns all users in an Okta account.
|
||||
func (om *OktaManager) getAllUsers() ([]*UserData, error) {
|
||||
qp := query.NewQueryParams(query.WithLimit(200))
|
||||
userList, resp, err := om.client.User.ListUsers(context.Background(), qp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
for resp.HasNextPage() {
|
||||
paginatedUsers := make([]*okta.User, 0)
|
||||
resp, err = resp.Next(context.Background(), &paginatedUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return nil, fmt.Errorf("unable to get all accounts, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
userList = append(userList, paginatedUsers...)
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0, len(userList))
|
||||
for _, user := range userList {
|
||||
userData, err := parseOktaUser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (om *OktaManager) InviteUserByID(_ context.Context, _ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Okta
|
||||
func (om *OktaManager) DeleteUser(_ context.Context, userID string) error {
|
||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOktaUser parse okta user to UserData.
|
||||
func parseOktaUser(user *okta.User) (*UserData, error) {
|
||||
var oktaUser struct {
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("invalid okta user")
|
||||
}
|
||||
|
||||
if user.Profile != nil {
|
||||
helper := JsonParser{}
|
||||
buf, err := helper.Marshal(*user.Profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = helper.Unmarshal(buf, &oktaUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &UserData{
|
||||
Email: oktaUser.Email,
|
||||
Name: strings.Join([]string{oktaUser.FirstName, oktaUser.LastName}, " "),
|
||||
ID: user.Id,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/okta/okta-sdk-golang/v2/okta"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseOktaUser(t *testing.T) {
|
||||
type parseOktaUserTest struct {
|
||||
name string
|
||||
inputProfile *okta.User
|
||||
expectedUserData *UserData
|
||||
assertErrFunc assert.ErrorAssertionFunc
|
||||
}
|
||||
|
||||
parseOktaTestCase1 := parseOktaUserTest{
|
||||
name: "Good Request",
|
||||
inputProfile: &okta.User{
|
||||
Id: "123",
|
||||
Profile: &okta.UserProfile{
|
||||
"email": "test@example.com",
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
},
|
||||
},
|
||||
expectedUserData: &UserData{
|
||||
Email: "test@example.com",
|
||||
Name: "John Doe",
|
||||
ID: "123",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "456",
|
||||
},
|
||||
},
|
||||
assertErrFunc: assert.NoError,
|
||||
}
|
||||
|
||||
parseOktaTestCase2 := parseOktaUserTest{
|
||||
name: "Invalid okta user",
|
||||
inputProfile: nil,
|
||||
expectedUserData: nil,
|
||||
assertErrFunc: assert.Error,
|
||||
}
|
||||
|
||||
for _, testCase := range []parseOktaUserTest{parseOktaTestCase1, parseOktaTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
userData, err := parseOktaUser(testCase.inputProfile)
|
||||
testCase.assertErrFunc(t, err, testCase.assertErrFunc)
|
||||
|
||||
if err == nil {
|
||||
assert.True(t, userDataEqual(testCase.expectedUserData, userData), "user data should match")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// userDataEqual helper function to compare UserData structs for equality.
|
||||
func userDataEqual(a, b *UserData) bool {
|
||||
if a.Email != b.Email || a.Name != b.Name || a.ID != b.ID {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,384 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type PocketIdManager struct {
|
||||
managementEndpoint string
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
type pocketIdCustomClaimDto struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type pocketIdUserDto struct {
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
Disabled bool `json:"disabled"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
ID string `json:"id"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
LastName string `json:"lastName"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Locale string `json:"locale"`
|
||||
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdUserCreateDto struct {
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
LastName string `json:"lastName,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdPaginatedUserDto struct {
|
||||
Data []pocketIdUserDto `json:"data"`
|
||||
Pagination pocketIdPaginationDto `json:"pagination"`
|
||||
}
|
||||
|
||||
type pocketIdPaginationDto struct {
|
||||
CurrentPage int `json:"currentPage"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
}
|
||||
|
||||
func (p *pocketIdUserDto) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.DisplayName,
|
||||
ID: p.ID,
|
||||
AppMetadata: AppMetadata{},
|
||||
}
|
||||
}
|
||||
|
||||
type pocketIdUserGroupDto struct {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
ID string `json:"id"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ManagementEndpoint == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
credentials := &PocketIdCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &PocketIdManager{
|
||||
managementEndpoint: config.ManagementEndpoint,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
|
||||
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
|
||||
if !slices.Contains(MethodsWithBody, method) && body != "" {
|
||||
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
|
||||
if query != nil {
|
||||
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
|
||||
}
|
||||
var req *http.Request
|
||||
var err error
|
||||
if body != "" {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
|
||||
} else {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("X-API-KEY", p.apiToken)
|
||||
|
||||
if body != "" {
|
||||
req.Header.Add("content-type", "application/json")
|
||||
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// getAllUsersPaginated fetches all users from PocketID API using pagination
|
||||
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
|
||||
var allUsers []pocketIdUserDto
|
||||
currentPage := 1
|
||||
|
||||
for {
|
||||
params := url.Values{}
|
||||
// Copy existing search parameters
|
||||
for key, values := range searchParams {
|
||||
params[key] = values
|
||||
}
|
||||
|
||||
params.Set("pagination[limit]", "100")
|
||||
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
|
||||
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles pocketIdPaginatedUserDto
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allUsers = append(allUsers, profiles.Data...)
|
||||
|
||||
// Check if we've reached the last page
|
||||
if currentPage >= profiles.Pagination.TotalPages {
|
||||
break
|
||||
}
|
||||
|
||||
currentPage++
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var user pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := user.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountId
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
firstLast := strings.Split(name, " ")
|
||||
|
||||
createUser := pocketIdUserCreateDto{
|
||||
Disabled: false,
|
||||
DisplayName: name,
|
||||
Email: email,
|
||||
FirstName: firstLast[0],
|
||||
LastName: firstLast[1],
|
||||
Username: firstLast[0] + "." + firstLast[1],
|
||||
}
|
||||
payload, err := p.helper.Marshal(createUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var newUser pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &newUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
params := url.Values{
|
||||
// This value a
|
||||
"search": []string{email},
|
||||
}
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profiles struct{ data []pocketIdUserDto }
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.data {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Manager = (*PocketIdManager)(nil)
|
||||
|
||||
type PocketIdClientConfig struct {
|
||||
APIToken string
|
||||
ManagementEndpoint string
|
||||
}
|
||||
|
||||
type PocketIdCredentials struct {
|
||||
clientConfig PocketIdClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
|
||||
|
||||
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
@@ -1,137 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
func TestNewPocketIdManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig PocketIdClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := PocketIdClientConfig{
|
||||
APIToken: "api_token",
|
||||
ManagementEndpoint: "http://localhost",
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
},
|
||||
{
|
||||
name: "Missing ManagementEndpoint",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: defaultTestConfig.APIToken,
|
||||
ManagementEndpoint: "",
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
{
|
||||
name: "Missing APIToken",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: "",
|
||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
md := AppMetadata{WTAccountID: "acc1"}
|
||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "u1", got.ID)
|
||||
assert.Equal(t, "user1@example.com", got.Email)
|
||||
assert.Equal(t, "User One", got.Name)
|
||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
||||
// Single page response with two users
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
assert.Equal(t, "u1", users[0].ID)
|
||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "u2", users[1].ID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accounts[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestPocketID_CreateUser(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newid", ud.ID)
|
||||
assert.Equal(t, "new@example.com", ud.Email)
|
||||
assert.Equal(t, "New User", ud.Name)
|
||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||
}
|
||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||
}
|
||||
|
||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
||||
// Same mock for both calls; returns OK with empty JSON
|
||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.DeleteUser(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
47
management/server/idp/test_util_test.go
Normal file
47
management/server/idp/test_util_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockHTTPClient is a mock implementation of ManagerHTTPClient for testing
|
||||
type mockHTTPClient struct {
|
||||
code int
|
||||
resBody string
|
||||
reqBody string
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if c.err != nil {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
if req.Body != nil {
|
||||
body, _ := io.ReadAll(req.Body)
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: io.NopCloser(bytes.NewReader([]byte(c.resBody))),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newTestJWT creates a test JWT token with the given expiration time in seconds
|
||||
func newTestJWT(t *testing.T, expiresIn int) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||
exp := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"exp":%d}`, exp)))
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("test-signature"))
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", header, payload, signature)
|
||||
}
|
||||
@@ -319,14 +319,19 @@ func (zc *ZitadelCredentials) Authenticate(ctx context.Context) (JWTToken, error
|
||||
}
|
||||
|
||||
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
// Note: Authorization data (account membership, invite status) is stored in NetBird's DB, not in the IdP.
|
||||
func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name string) (*UserData, error) {
|
||||
firstLast := strings.SplitN(name, " ", 2)
|
||||
lastName := firstLast[0]
|
||||
if len(firstLast) > 1 {
|
||||
lastName = firstLast[1]
|
||||
}
|
||||
|
||||
var addUser = map[string]any{
|
||||
"userName": email,
|
||||
"profile": map[string]string{
|
||||
"firstName": firstLast[0],
|
||||
"lastName": firstLast[0],
|
||||
"lastName": lastName,
|
||||
"displayName": name,
|
||||
},
|
||||
"email": map[string]any{
|
||||
@@ -357,18 +362,11 @@ func (zm *ZitadelManager) CreateUser(ctx context.Context, email, name, accountID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
return &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.UserId,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail searches users with a given email.
|
||||
@@ -413,7 +411,7 @@ func (zm *ZitadelManager) GetUserByEmail(ctx context.Context, email string) ([]*
|
||||
}
|
||||
|
||||
// GetUserDataByID requests user data from zitadel via ID.
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||
func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string) (*UserData, error) {
|
||||
body, err := zm.get(ctx, "users/"+userID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -429,43 +427,12 @@ func (zm *ZitadelManager) GetUserDataByID(ctx context.Context, userID string, ap
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := profile.User.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
return profile.User.userData(), nil
|
||||
}
|
||||
|
||||
// GetAccount returns all the users for a given profile.
|
||||
func (zm *ZitadelManager) GetAccount(ctx context.Context, accountID string) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
var profiles struct{ Result []zitadelProfile }
|
||||
err = zm.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountID
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
// It returns a list of users indexed by accountID.
|
||||
func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// GetAllUsers returns all users from the IdP.
|
||||
// Used for cache warming - NetBird matches these against its own user database.
|
||||
func (zm *ZitadelManager) GetAllUsers(ctx context.Context) ([]*UserData, error) {
|
||||
body, err := zm.post(ctx, "users/_search", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -481,19 +448,12 @@ func (zm *ZitadelManager) GetAllAccounts(ctx context.Context) (map[string][]*Use
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
users := make([]*UserData, 0, len(profiles.Result))
|
||||
for _, profile := range profiles.Result {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
// Metadata values are base64 encoded.
|
||||
func (zm *ZitadelManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
return users, nil
|
||||
}
|
||||
|
||||
type inviteUserRequest struct {
|
||||
|
||||
@@ -288,16 +288,14 @@ func TestZitadelAuthenticate(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestZitadelProfile(t *testing.T) {
|
||||
type azureProfileTest struct {
|
||||
type zitadelProfileTest struct {
|
||||
name string
|
||||
invite bool
|
||||
inputProfile zitadelProfile
|
||||
expectedUserData UserData
|
||||
}
|
||||
|
||||
azureProfileTestCase1 := azureProfileTest{
|
||||
name: "User Request",
|
||||
invite: false,
|
||||
zitadelProfileTestCase1 := zitadelProfileTest{
|
||||
name: "User Request",
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test1",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
@@ -322,15 +320,11 @@ func TestZitadelProfile(t *testing.T) {
|
||||
ID: "test1",
|
||||
Name: "ZITADEL Admin",
|
||||
Email: "test1@mail.com",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
azureProfileTestCase2 := azureProfileTest{
|
||||
name: "Service User Request",
|
||||
invite: true,
|
||||
zitadelProfileTestCase2 := zitadelProfileTest{
|
||||
name: "Service User Request",
|
||||
inputProfile: zitadelProfile{
|
||||
ID: "test2",
|
||||
State: "USER_STATE_ACTIVE",
|
||||
@@ -345,15 +339,11 @@ func TestZitadelProfile(t *testing.T) {
|
||||
ID: "test2",
|
||||
Name: "machine",
|
||||
Email: "machine",
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: "1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2} {
|
||||
for _, testCase := range []zitadelProfileTest{zitadelProfileTestCase1, zitadelProfileTestCase2} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||
userData := testCase.inputProfile.userData()
|
||||
|
||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||
|
||||
Reference in New Issue
Block a user