mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Refactor Azure IDP manager
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
package idp
|
package idp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -11,18 +10,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt"
|
"github.com/golang-jwt/jwt"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const profileFields = "id,displayName,mail,userPrincipalName"
|
||||||
// azure extension properties template
|
|
||||||
wtAccountIDTpl = "extension_%s_wt_account_id"
|
|
||||||
wtPendingInviteTpl = "extension_%s_wt_pending_invite"
|
|
||||||
|
|
||||||
profileFields = "id,displayName,mail,userPrincipalName"
|
|
||||||
extensionFields = "id,name,targetObjects"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AzureManager azure manager client instance.
|
// AzureManager azure manager client instance.
|
||||||
type AzureManager struct {
|
type AzureManager struct {
|
||||||
@@ -58,21 +51,6 @@ type AzureCredentials struct {
|
|||||||
// azureProfile represents an azure user profile.
|
// azureProfile represents an azure user profile.
|
||||||
type azureProfile map[string]any
|
type azureProfile map[string]any
|
||||||
|
|
||||||
// passwordProfile represent authentication method for,
|
|
||||||
// newly created user profile.
|
|
||||||
type passwordProfile struct {
|
|
||||||
ForceChangePasswordNextSignIn bool `json:"forceChangePasswordNextSignIn"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// azureExtension represent custom attribute,
|
|
||||||
// that can be added to user objects in Azure Active Directory (AD).
|
|
||||||
type azureExtension struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
DataType string `json:"dataType"`
|
|
||||||
TargetObjects []string `json:"targetObjects"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAzureManager creates a new instance of the AzureManager.
|
// NewAzureManager creates a new instance of the AzureManager.
|
||||||
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) (*AzureManager, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
@@ -115,7 +93,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
appMetrics: appMetrics,
|
appMetrics: appMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
manager := &AzureManager{
|
return &AzureManager{
|
||||||
ObjectID: config.ObjectID,
|
ObjectID: config.ObjectID,
|
||||||
ClientID: config.ClientID,
|
ClientID: config.ClientID,
|
||||||
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
GraphAPIEndpoint: config.GraphAPIEndpoint,
|
||||||
@@ -123,14 +101,7 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
|
|||||||
credentials: credentials,
|
credentials: credentials,
|
||||||
helper: helper,
|
helper: helper,
|
||||||
appMetrics: appMetrics,
|
appMetrics: appMetrics,
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
err := manager.configureAppMetadata()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return manager, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from azure.
|
||||||
@@ -236,44 +207,14 @@ func (ac *AzureCredentials) Authenticate() (JWTToken, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateUser creates a new user in azure AD Idp.
|
// CreateUser creates a new user in azure AD Idp.
|
||||||
func (am *AzureManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
|
func (am *AzureManager) CreateUser(_, _, _, _ string) (*UserData, error) {
|
||||||
payload, err := buildAzureCreateUserRequestPayload(email, name, accountID, am.ClientID)
|
return nil, fmt.Errorf("method CreateUser not implemented")
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := am.post("users", payload)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountCreateUser()
|
|
||||||
}
|
|
||||||
|
|
||||||
var profile azureProfile
|
|
||||||
err = am.helper.Unmarshal(body, &profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
profile[wtAccountIDField] = accountID
|
|
||||||
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
profile[wtPendingInviteField] = true
|
|
||||||
|
|
||||||
return profile.userData(am.ClientID), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserDataByID requests user data from keycloak via ID.
|
// GetUserDataByID requests user data from keycloak via ID.
|
||||||
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) {
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
|
||||||
|
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", selectFields)
|
q.Add("$select", profileFields)
|
||||||
|
|
||||||
body, err := am.get("users/"+userID, q)
|
body, err := am.get("users/"+userID, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -290,18 +231,17 @@ func (am *AzureManager) GetUserDataByID(userID string, appMetadata AppMetadata)
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return profile.userData(am.ClientID), nil
|
userData := profile.userData()
|
||||||
|
userData.AppMetadata = appMetadata
|
||||||
|
|
||||||
|
return userData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByEmail searches users with a given email.
|
// GetUserByEmail searches users with a given email.
|
||||||
// If no users have been found, this function returns an empty list.
|
// If no users have been found, this function returns an empty list.
|
||||||
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
|
||||||
|
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", selectFields)
|
q.Add("$select", profileFields)
|
||||||
|
|
||||||
body, err := am.get("users/"+email, q)
|
body, err := am.get("users/"+email, q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -319,20 +259,15 @@ func (am *AzureManager) GetUserByEmail(email string) ([]*UserData, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
users := make([]*UserData, 0)
|
users := make([]*UserData, 0)
|
||||||
users = append(users, profile.userData(am.ClientID))
|
users = append(users, profile.userData())
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccount returns all the users for a given profile.
|
// GetAccount returns all the users for a given profile.
|
||||||
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
|
||||||
|
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", selectFields)
|
q.Add("$select", profileFields)
|
||||||
q.Add("$filter", fmt.Sprintf("%s eq '%s'", wtAccountIDField, accountID))
|
|
||||||
|
|
||||||
body, err := am.get("users", q)
|
body, err := am.get("users", q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -351,7 +286,10 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
|||||||
|
|
||||||
users := make([]*UserData, 0)
|
users := make([]*UserData, 0)
|
||||||
for _, profile := range profiles.Value {
|
for _, profile := range profiles.Value {
|
||||||
users = append(users, profile.userData(am.ClientID))
|
userData := profile.userData()
|
||||||
|
userData.AppMetadata.WTAccountID = accountID
|
||||||
|
|
||||||
|
users = append(users, userData)
|
||||||
}
|
}
|
||||||
|
|
||||||
return users, nil
|
return users, nil
|
||||||
@@ -360,12 +298,8 @@ func (am *AzureManager) GetAccount(accountID string) ([]*UserData, error) {
|
|||||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||||
// It returns a list of users indexed by accountID.
|
// It returns a list of users indexed by accountID.
|
||||||
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
selectFields := strings.Join([]string{profileFields, wtAccountIDField, wtPendingInviteField}, ",")
|
|
||||||
|
|
||||||
q := url.Values{}
|
q := url.Values{}
|
||||||
q.Add("$select", selectFields)
|
q.Add("$select", profileFields)
|
||||||
|
|
||||||
body, err := am.get("users", q)
|
body, err := am.get("users", q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -384,67 +318,17 @@ func (am *AzureManager) GetAllAccounts() (map[string][]*UserData, error) {
|
|||||||
|
|
||||||
indexedUsers := make(map[string][]*UserData)
|
indexedUsers := make(map[string][]*UserData)
|
||||||
for _, profile := range profiles.Value {
|
for _, profile := range profiles.Value {
|
||||||
userData := profile.userData(am.ClientID)
|
userData := profile.userData()
|
||||||
|
|
||||||
accountID := userData.AppMetadata.WTAccountID
|
|
||||||
if accountID != "" {
|
|
||||||
if _, ok := indexedUsers[accountID]; !ok {
|
|
||||||
indexedUsers[accountID] = make([]*UserData, 0)
|
|
||||||
}
|
|
||||||
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
accountID := "unset"
|
||||||
|
indexedUsers[accountID] = append(indexedUsers[accountID], userData)
|
||||||
}
|
}
|
||||||
|
|
||||||
return indexedUsers, nil
|
return indexedUsers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateUserAppMetadata updates user app metadata based on userID.
|
// UpdateUserAppMetadata updates user app metadata based on userID.
|
||||||
func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
func (am *AzureManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
|
|
||||||
data, err := am.helper.Marshal(map[string]any{
|
|
||||||
wtAccountIDField: appMetadata.WTAccountID,
|
|
||||||
wtPendingInviteField: appMetadata.WTPendingInvite,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
payload := strings.NewReader(string(data))
|
|
||||||
|
|
||||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, userID)
|
|
||||||
req, err := http.NewRequest(http.MethodPatch, reqURL, payload)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
|
||||||
req.Header.Add("content-type", "application/json")
|
|
||||||
|
|
||||||
log.Debugf("updating idp metadata for user %s", userID)
|
|
||||||
|
|
||||||
resp, err := am.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountRequestError()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if am.appMetrics != nil {
|
|
||||||
am.appMetrics.IDPMetrics().CountUpdateUserAppMetadata()
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusNoContent {
|
|
||||||
return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -454,7 +338,7 @@ func (am *AzureManager) InviteUserByID(_ string) error {
|
|||||||
return fmt.Errorf("method InviteUserByID not implemented")
|
return fmt.Errorf("method InviteUserByID not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteUser from Azure
|
// DeleteUser from Azure.
|
||||||
func (am *AzureManager) DeleteUser(userID string) error {
|
func (am *AzureManager) DeleteUser(userID string) error {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -491,81 +375,6 @@ func (am *AzureManager) DeleteUser(userID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
|
||||||
q := url.Values{}
|
|
||||||
q.Add("$select", extensionFields)
|
|
||||||
|
|
||||||
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
|
|
||||||
body, err := am.get(resource, q)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var extensions struct{ Value []azureExtension }
|
|
||||||
err = am.helper.Unmarshal(body, &extensions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return extensions.Value, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (am *AzureManager) createUserExtension(name string) (*azureExtension, error) {
|
|
||||||
extension := azureExtension{
|
|
||||||
Name: name,
|
|
||||||
DataType: "string",
|
|
||||||
TargetObjects: []string{"User"},
|
|
||||||
}
|
|
||||||
|
|
||||||
payload, err := am.helper.Marshal(extension)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resource := fmt.Sprintf("applications/%s/extensionProperties", am.ObjectID)
|
|
||||||
body, err := am.post(resource, string(payload))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var userExtension azureExtension
|
|
||||||
err = am.helper.Unmarshal(body, &userExtension)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &userExtension, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// configureAppMetadata sets up app metadata extensions if they do not exists.
|
|
||||||
func (am *AzureManager) configureAppMetadata() error {
|
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, am.ClientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, am.ClientID)
|
|
||||||
|
|
||||||
extensions, err := am.getUserExtensions()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the wt_account_id extension does not already exist, create it.
|
|
||||||
if !hasExtension(extensions, wtAccountIDField) {
|
|
||||||
_, err = am.createUserExtension(wtAccountID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the wt_pending_invite extension does not already exist, create it.
|
|
||||||
if !hasExtension(extensions, wtPendingInviteField) {
|
|
||||||
_, err = am.createUserExtension(wtPendingInvite)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// get perform Get requests.
|
// get perform Get requests.
|
||||||
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
func (am *AzureManager) get(resource string, q url.Values) ([]byte, error) {
|
||||||
jwtToken, err := am.credentials.Authenticate()
|
jwtToken, err := am.credentials.Authenticate()
|
||||||
@@ -639,7 +448,7 @@ func (am *AzureManager) post(resource string, body string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// userData construct user data from keycloak profile.
|
// userData construct user data from keycloak profile.
|
||||||
func (ap azureProfile) userData(clientID string) *UserData {
|
func (ap azureProfile) userData() *UserData {
|
||||||
id, ok := ap["id"].(string)
|
id, ok := ap["id"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
id = ""
|
id = ""
|
||||||
@@ -655,66 +464,9 @@ func (ap azureProfile) userData(clientID string) *UserData {
|
|||||||
name = ""
|
name = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
accountIDField := extensionName(wtAccountIDTpl, clientID)
|
|
||||||
accountID, ok := ap[accountIDField].(string)
|
|
||||||
if !ok {
|
|
||||||
accountID = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
pendingInviteField := extensionName(wtPendingInviteTpl, clientID)
|
|
||||||
pendingInvite, ok := ap[pendingInviteField].(bool)
|
|
||||||
if !ok {
|
|
||||||
pendingInvite = false
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UserData{
|
return &UserData{
|
||||||
Email: email,
|
Email: email,
|
||||||
Name: name,
|
Name: name,
|
||||||
ID: id,
|
ID: id,
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: accountID,
|
|
||||||
WTPendingInvite: &pendingInvite,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAzureCreateUserRequestPayload(email, name, accountID, clientID string) (string, error) {
|
|
||||||
wtAccountIDField := extensionName(wtAccountIDTpl, clientID)
|
|
||||||
wtPendingInviteField := extensionName(wtPendingInviteTpl, clientID)
|
|
||||||
|
|
||||||
req := &azureProfile{
|
|
||||||
"accountEnabled": true,
|
|
||||||
"displayName": name,
|
|
||||||
"mailNickName": strings.Join(strings.Split(name, " "), ""),
|
|
||||||
"userPrincipalName": email,
|
|
||||||
"passwordProfile": passwordProfile{
|
|
||||||
ForceChangePasswordNextSignIn: true,
|
|
||||||
Password: GeneratePassword(8, 1, 1, 1),
|
|
||||||
},
|
|
||||||
wtAccountIDField: accountID,
|
|
||||||
wtPendingInviteField: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
str, err := json.Marshal(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(str), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func extensionName(extensionTpl, clientID string) string {
|
|
||||||
clientID = strings.ReplaceAll(clientID, "-", "")
|
|
||||||
return fmt.Sprintf(extensionTpl, clientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hasExtension checks whether a given extension by name,
|
|
||||||
// exists in an list of extensions.
|
|
||||||
func hasExtension(extensions []azureExtension, name string) bool {
|
|
||||||
for _, ext := range extensions {
|
|
||||||
if ext.Name == name {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -124,206 +124,63 @@ func TestAzureAuthenticate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureUpdateUserAppMetadata(t *testing.T) {
|
|
||||||
type updateUserAppMetadataTest struct {
|
|
||||||
name string
|
|
||||||
inputReqBody string
|
|
||||||
expectedReqBody string
|
|
||||||
appMetadata AppMetadata
|
|
||||||
statusCode int
|
|
||||||
helper ManagerHelper
|
|
||||||
managerCreds ManagerCredentials
|
|
||||||
assertErrFunc assert.ErrorAssertionFunc
|
|
||||||
assertErrFuncMessage string
|
|
||||||
}
|
|
||||||
|
|
||||||
appMetadata := AppMetadata{WTAccountID: "ok"}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
|
|
||||||
name: "Bad Authentication",
|
|
||||||
expectedReqBody: "",
|
|
||||||
appMetadata: appMetadata,
|
|
||||||
statusCode: 400,
|
|
||||||
helper: JsonParser{},
|
|
||||||
managerCreds: &mockAzureCredentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
err: fmt.Errorf("error"),
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "should return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
|
|
||||||
name: "Bad Status Code",
|
|
||||||
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
|
|
||||||
appMetadata: appMetadata,
|
|
||||||
statusCode: 400,
|
|
||||||
helper: JsonParser{},
|
|
||||||
managerCreds: &mockAzureCredentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "should return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{
|
|
||||||
name: "Bad Response Parsing",
|
|
||||||
statusCode: 400,
|
|
||||||
helper: &mockJsonParser{marshalErrorString: "error"},
|
|
||||||
managerCreds: &mockAzureCredentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.Error,
|
|
||||||
assertErrFuncMessage: "should return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
|
|
||||||
name: "Good request",
|
|
||||||
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":null}", appMetadata.WTAccountID),
|
|
||||||
appMetadata: appMetadata,
|
|
||||||
statusCode: 204,
|
|
||||||
helper: JsonParser{},
|
|
||||||
managerCreds: &mockAzureCredentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
invite := true
|
|
||||||
updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{
|
|
||||||
name: "Update Pending Invite",
|
|
||||||
expectedReqBody: fmt.Sprintf("{\"extension__wt_account_id\":\"%s\",\"extension__wt_pending_invite\":true}", appMetadata.WTAccountID),
|
|
||||||
appMetadata: AppMetadata{
|
|
||||||
WTAccountID: "ok",
|
|
||||||
WTPendingInvite: &invite,
|
|
||||||
},
|
|
||||||
statusCode: 204,
|
|
||||||
helper: JsonParser{},
|
|
||||||
managerCreds: &mockAzureCredentials{
|
|
||||||
jwtToken: JWTToken{},
|
|
||||||
},
|
|
||||||
assertErrFunc: assert.NoError,
|
|
||||||
assertErrFuncMessage: "shouldn't return error",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2,
|
|
||||||
updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
|
||||||
reqClient := mockHTTPClient{
|
|
||||||
resBody: testCase.inputReqBody,
|
|
||||||
code: testCase.statusCode,
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &AzureManager{
|
|
||||||
httpClient: &reqClient,
|
|
||||||
credentials: testCase.managerCreds,
|
|
||||||
helper: testCase.helper,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := manager.UpdateUserAppMetadata("1", testCase.appMetadata)
|
|
||||||
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)
|
|
||||||
|
|
||||||
assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAzureProfile(t *testing.T) {
|
func TestAzureProfile(t *testing.T) {
|
||||||
type azureProfileTest struct {
|
type azureProfileTest struct {
|
||||||
name string
|
name string
|
||||||
clientID string
|
|
||||||
invite bool
|
invite bool
|
||||||
inputProfile azureProfile
|
inputProfile azureProfile
|
||||||
expectedUserData UserData
|
expectedUserData UserData
|
||||||
}
|
}
|
||||||
|
|
||||||
azureProfileTestCase1 := azureProfileTest{
|
azureProfileTestCase1 := azureProfileTest{
|
||||||
name: "Good Request",
|
name: "Good Request",
|
||||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
invite: false,
|
||||||
invite: false,
|
|
||||||
inputProfile: azureProfile{
|
inputProfile: azureProfile{
|
||||||
"id": "test1",
|
"id": "test1",
|
||||||
"displayName": "John Doe",
|
"displayName": "John Doe",
|
||||||
"userPrincipalName": "test1@test.com",
|
"userPrincipalName": "test1@test.com",
|
||||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
|
||||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
|
|
||||||
},
|
},
|
||||||
expectedUserData: UserData{
|
expectedUserData: UserData{
|
||||||
Email: "test1@test.com",
|
Email: "test1@test.com",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
ID: "test1",
|
ID: "test1",
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: "1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
azureProfileTestCase2 := azureProfileTest{
|
azureProfileTestCase2 := azureProfileTest{
|
||||||
name: "Missing User ID",
|
name: "Missing User ID",
|
||||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
invite: true,
|
||||||
invite: true,
|
|
||||||
inputProfile: azureProfile{
|
inputProfile: azureProfile{
|
||||||
"displayName": "John Doe",
|
"displayName": "John Doe",
|
||||||
"userPrincipalName": "test2@test.com",
|
"userPrincipalName": "test2@test.com",
|
||||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
|
||||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": true,
|
|
||||||
},
|
},
|
||||||
expectedUserData: UserData{
|
expectedUserData: UserData{
|
||||||
Email: "test2@test.com",
|
Email: "test2@test.com",
|
||||||
Name: "John Doe",
|
Name: "John Doe",
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: "1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
azureProfileTestCase3 := azureProfileTest{
|
azureProfileTestCase3 := azureProfileTest{
|
||||||
name: "Missing User Name",
|
name: "Missing User Name",
|
||||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
invite: false,
|
||||||
invite: false,
|
|
||||||
inputProfile: azureProfile{
|
inputProfile: azureProfile{
|
||||||
"id": "test3",
|
"id": "test3",
|
||||||
"userPrincipalName": "test3@test.com",
|
"userPrincipalName": "test3@test.com",
|
||||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_account_id": "1",
|
|
||||||
"extension_25d0b095048440d29fd303f8f4abbb3c_wt_pending_invite": false,
|
|
||||||
},
|
},
|
||||||
expectedUserData: UserData{
|
expectedUserData: UserData{
|
||||||
ID: "test3",
|
ID: "test3",
|
||||||
Email: "test3@test.com",
|
Email: "test3@test.com",
|
||||||
AppMetadata: AppMetadata{
|
|
||||||
WTAccountID: "1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
azureProfileTestCase4 := azureProfileTest{
|
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3} {
|
||||||
name: "Missing Extension Fields",
|
|
||||||
clientID: "25d0b095-0484-40d2-9fd3-03f8f4abbb3c",
|
|
||||||
invite: false,
|
|
||||||
inputProfile: azureProfile{
|
|
||||||
"id": "test4",
|
|
||||||
"displayName": "John Doe",
|
|
||||||
"userPrincipalName": "test4@test.com",
|
|
||||||
},
|
|
||||||
expectedUserData: UserData{
|
|
||||||
ID: "test4",
|
|
||||||
Name: "John Doe",
|
|
||||||
Email: "test4@test.com",
|
|
||||||
AppMetadata: AppMetadata{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range []azureProfileTest{azureProfileTestCase1, azureProfileTestCase2, azureProfileTestCase3, azureProfileTestCase4} {
|
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
testCase.expectedUserData.AppMetadata.WTPendingInvite = &testCase.invite
|
||||||
userData := testCase.inputProfile.userData(testCase.clientID)
|
userData := testCase.inputProfile.userData()
|
||||||
|
|
||||||
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
assert.Equal(t, testCase.expectedUserData.ID, userData.ID, "User id should match")
|
||||||
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
assert.Equal(t, testCase.expectedUserData.Email, userData.Email, "User email should match")
|
||||||
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
assert.Equal(t, testCase.expectedUserData.Name, userData.Name, "User name should match")
|
||||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTAccountID, userData.AppMetadata.WTAccountID, "Account id should match")
|
|
||||||
assert.Equal(t, testCase.expectedUserData.AppMetadata.WTPendingInvite, userData.AppMetadata.WTPendingInvite, "Pending invite should match")
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user