diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 724a3541d..4dd950369 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -2,10 +2,11 @@ package idp import ( "fmt" - "github.com/netbirdio/netbird/management/server/telemetry" "net/http" "strings" "time" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // Manager idp manager interface @@ -20,8 +21,9 @@ type Manager interface { // Config an idp configuration struct to be loaded from management server's config file type Config struct { - ManagerType string - Auth0ClientCredentials Auth0ClientConfig + ManagerType string + Auth0ClientCredentials Auth0ClientConfig + KeycloakClientCredentials KeycloakClientConfig } // ManagerCredentials interface that authenticates using the credential of each type of idp @@ -71,6 +73,8 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) return nil, nil case "auth0": return NewAuth0Manager(config.Auth0ClientCredentials, appMetrics) + case "keycloak": + return NewKeycloakManager(config.KeycloakClientCredentials, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go new file mode 100644 index 000000000..f9fc94ae7 --- /dev/null +++ b/management/server/idp/keycloak.go @@ -0,0 +1,581 @@ +package idp + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/telemetry" + log "github.com/sirupsen/logrus" +) + +const ( + wtAccountID = "wt_account_id" + wtPendingInvite = "wt_pending_invite" +) + +// KeycloakManager keycloak manager client instance. +type KeycloakManager struct { + adminEndpoint string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +// KeycloakClientConfig keycloak manager client configurations. +type KeycloakClientConfig struct { + ClientID string + ClientSecret string + AdminEndpoint string + TokenEndpoint string + GrantType string +} + +// KeycloakCredentials keycloak authentication information. +type KeycloakCredentials struct { + clientConfig KeycloakClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + jwtToken JWTToken + mux sync.Mutex + appMetrics telemetry.AppMetrics +} + +// keycloakUserCredential describe the authentication method for, +// newly created user profile. +type keycloakUserCredential struct { + Type string `json:"type"` + Value string `json:"value"` + Temporary bool `json:"temporary"` +} + +// keycloakUserAttributes holds additional user data fields. +type keycloakUserAttributes map[string][]string + +// createUserRequest is a user create request. +type keycloakCreateUserRequest struct { + Email string `json:"email"` + Username string `json:"username"` + Enabled bool `json:"enabled"` + EmailVerified bool `json:"emailVerified"` + Credentials []keycloakUserCredential `json:"credentials"` + Attributes keycloakUserAttributes `json:"attributes"` +} + +// keycloakProfile represents an keycloak user profile response. +type keycloakProfile struct { + ID string `json:"id"` + CreatedTimestamp int64 `json:"createdTimestamp"` + Username string `json:"username"` + Email string `json:"email"` + Attributes keycloakUserAttributes `json:"attributes"` +} + +// NewKeycloakManager creates a new instance of the KeycloakManager. +func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMetrics) (*KeycloakManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + + helper := JsonParser{} + + if config.ClientID == "" || config.ClientSecret == "" || config.GrantType == "" || config.AdminEndpoint == "" || config.TokenEndpoint == "" { + return nil, fmt.Errorf("keycloak idp configuration is not complete") + } + + if config.GrantType != "client_credentials" { + return nil, fmt.Errorf("keycloak idp configuration failed. Grant Type should be client_credentials") + } + + credentials := &KeycloakCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &KeycloakManager{ + adminEndpoint: config.AdminEndpoint, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +// jwtStillValid returns true if the token still valid and have enough time to be used and get a response from keycloak. +func (kc *KeycloakCredentials) jwtStillValid() bool { + return !kc.jwtToken.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(kc.jwtToken.expiresInTime) +} + +// requestJWTToken performs request to get jwt token. +func (kc *KeycloakCredentials) requestJWTToken() (*http.Response, error) { + data := url.Values{} + data.Set("client_id", kc.clientConfig.ClientID) + data.Set("client_secret", kc.clientConfig.ClientSecret) + data.Set("grant_type", kc.clientConfig.GrantType) + + payload := strings.NewReader(data.Encode()) + req, err := http.NewRequest(http.MethodPost, kc.clientConfig.TokenEndpoint, payload) + if err != nil { + return nil, err + } + req.Header.Add("content-type", "application/x-www-form-urlencoded") + + log.Debug("requesting new jwt token for keycloak idp manager") + + resp, err := kc.httpClient.Do(req) + if err != nil { + if kc.appMetrics != nil { + kc.appMetrics.IDPMetrics().CountRequestError() + } + + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unable to get keycloak token, statusCode %d", resp.StatusCode) + } + + return resp, nil +} + +// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds +func (kc *KeycloakCredentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) { + jwtToken := JWTToken{} + body, err := io.ReadAll(rawBody) + if err != nil { + return jwtToken, err + } + + err = kc.helper.Unmarshal(body, &jwtToken) + if err != nil { + return jwtToken, err + } + + if jwtToken.ExpiresIn == 0 && jwtToken.AccessToken == "" { + return jwtToken, fmt.Errorf("error while reading response body, expires_in: %d and access_token: %s", jwtToken.ExpiresIn, jwtToken.AccessToken) + } + + data, err := jwt.DecodeSegment(strings.Split(jwtToken.AccessToken, ".")[1]) + if err != nil { + return jwtToken, err + } + + // Exp maps into exp from jwt token + var IssuedAt struct{ Exp int64 } + err = kc.helper.Unmarshal(data, &IssuedAt) + if err != nil { + return jwtToken, err + } + jwtToken.expiresInTime = time.Unix(IssuedAt.Exp, 0) + + return jwtToken, nil +} + +// Authenticate retrieves access token to use the keycloak Management API. +func (kc *KeycloakCredentials) Authenticate() (JWTToken, error) { + kc.mux.Lock() + defer kc.mux.Unlock() + + if kc.appMetrics != nil { + kc.appMetrics.IDPMetrics().CountAuthenticate() + } + + // reuse the token without requesting a new one if it is not expired, + // and if expiry time is sufficient time available to make a request. + if kc.jwtStillValid() { + return kc.jwtToken, nil + } + + resp, err := kc.requestJWTToken() + if err != nil { + return kc.jwtToken, err + } + defer resp.Body.Close() + + jwtToken, err := kc.parseRequestJWTResponse(resp.Body) + if err != nil { + return kc.jwtToken, err + } + + kc.jwtToken = jwtToken + + return kc.jwtToken, nil +} + +// CreateUser creates a new user in keycloak Idp and sends an invite. +func (km *KeycloakManager) CreateUser(email string, name string, accountID string) (*UserData, error) { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return nil, err + } + + invite := true + appMetadata := AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &invite, + } + + payloadString, err := buildKeycloakCreateUserRequestPayload(email, name, appMetadata) + if err != nil { + return nil, err + } + + reqURL := fmt.Sprintf("%s/users", km.adminEndpoint) + payload := strings.NewReader(payloadString) + + req, err := http.NewRequest(http.MethodPost, reqURL, payload) + if err != nil { + return nil, err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountCreateUser() + } + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode) + } + + locationHeader := resp.Header.Get("location") + userID, err := extractUserIDFromLocationHeader(locationHeader) + if err != nil { + return nil, err + } + + return km.GetUserDataByID(userID, appMetadata) +} + +// GetUserByEmail searches users with a given email. +// If no users have been found, this function returns an empty list. +func (km *KeycloakManager) GetUserByEmail(email string) ([]*UserData, error) { + q := url.Values{} + q.Add("email", email) + q.Add("exact", "true") + + body, err := km.get("users", q) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles { + users = append(users, profile.userData()) + } + + return users, nil +} + +// GetUserDataByID requests user data from keycloak via ID. +func (km *KeycloakManager) GetUserDataByID(userID string, appMetadata AppMetadata) (*UserData, error) { + body, err := km.get("users/"+userID, nil) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var profile keycloakProfile + err = km.helper.Unmarshal(body, &profile) + if err != nil { + return nil, err + } + + return profile.userData(), nil +} + +// GetAccount returns all the users for a given profile. +func (km *KeycloakManager) GetAccount(accountID string) ([]*UserData, error) { + q := url.Values{} + q.Add("q", wtAccountID+":"+accountID) + + body, err := km.get("users", q) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetAccount() + } + + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles { + users = append(users, profile.userData()) + } + + return users, nil +} + +// GetAllAccounts gets all registered accounts with corresponding user data. +// It returns a list of users indexed by accountID. +func (km *KeycloakManager) GetAllAccounts() (map[string][]*UserData, error) { + totalUsers, err := km.totalUsersCount() + if err != nil { + return nil, err + } + + q := url.Values{} + q.Add("max", fmt.Sprint(*totalUsers)) + + body, err := km.get("users", q) + if err != nil { + return nil, err + } + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + profiles := make([]keycloakProfile, 0) + err = km.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range profiles { + userData := profile.userData() + + accountID := userData.AppMetadata.WTAccountID + if accountID != "" { + if _, ok := indexedUsers[accountID]; !ok { + indexedUsers[accountID] = make([]*UserData, 0) + } + indexedUsers[accountID] = append(indexedUsers[accountID], userData) + } + } + + return indexedUsers, nil +} + +// UpdateUserAppMetadata updates user app metadata based on userID and metadata map. +func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return err + } + + attrs := keycloakUserAttributes{} + attrs.Set(wtAccountID, appMetadata.WTAccountID) + if appMetadata.WTPendingInvite != nil { + attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite)) + } else { + attrs.Set(wtPendingInvite, "false") + } + + reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, userID) + data, err := km.helper.Marshal(map[string]any{ + "attributes": attrs, + }) + if err != nil { + return err + } + payload := strings.NewReader(string(data)) + + req, err := http.NewRequest(http.MethodPut, reqURL, payload) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + log.Debugf("updating IdP metadata for user %s", userID) + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + return err + } + defer resp.Body.Close() + + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountUpdateUserAppMetadata() + } + + if resp.StatusCode != http.StatusNoContent { + return fmt.Errorf("unable to update the appMetadata, statusCode %d", resp.StatusCode) + } + + return nil +} + +func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) { + attrs := keycloakUserAttributes{} + attrs.Set(wtAccountID, appMetadata.WTAccountID) + attrs.Set(wtPendingInvite, strconv.FormatBool(*appMetadata.WTPendingInvite)) + + req := &keycloakCreateUserRequest{ + Email: email, + Username: name, + Enabled: true, + EmailVerified: true, + Credentials: []keycloakUserCredential{ + { + Type: "password", + Value: GeneratePassword(8, 1, 1, 1), + Temporary: false, + }, + }, + Attributes: attrs, + } + + str, err := json.Marshal(req) + if err != nil { + return "", err + } + + return string(str), nil +} + +// get perform Get requests. +func (km *KeycloakManager) get(resource string, q url.Values) ([]byte, error) { + jwtToken, err := km.credentials.Authenticate() + if err != nil { + return nil, err + } + + reqURL := fmt.Sprintf("%s/%s?%s", km.adminEndpoint, resource, q.Encode()) + req, err := http.NewRequest(http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) + req.Header.Add("content-type", "application/json") + + resp, err := km.httpClient.Do(req) + if err != nil { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestError() + } + + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + if km.appMetrics != nil { + km.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// totalUsersCount returns the total count of all user created. +// Used when fetching all registered accounts with pagination. +func (km *KeycloakManager) totalUsersCount() (*int, error) { + body, err := km.get("users/count", nil) + if err != nil { + return nil, err + } + + count, err := strconv.Atoi(string(body)) + if err != nil { + return nil, err + } + + return &count, nil +} + +// extractUserIDFromLocationHeader extracts the user ID from the location, +// header once the user is created successfully +func extractUserIDFromLocationHeader(locationHeader string) (string, error) { + userURL, err := url.Parse(locationHeader) + if err != nil { + return "", err + } + + return path.Base(userURL.Path), nil +} + +// userData construct user data from keycloak profile. +func (kp keycloakProfile) userData() *UserData { + accountID := kp.Attributes.Get(wtAccountID) + pendingInvite, err := strconv.ParseBool(kp.Attributes.Get(wtPendingInvite)) + if err != nil { + pendingInvite = false + } + + return &UserData{ + Email: kp.Email, + Name: kp.Username, + ID: kp.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pendingInvite, + }, + } +} + +// Set sets the key to value. It replaces any existing +// values. +func (ka keycloakUserAttributes) Set(key, value string) { + ka[key] = []string{value} +} + +// Get returns the first value associated with the given key. +// If there are no values associated with the key, Get returns +// the empty string. +func (ka keycloakUserAttributes) Get(key string) string { + if ka == nil { + return "" + } + + values := ka[key] + if len(values) == 0 { + return "" + } + return values[0] +} diff --git a/management/server/idp/keycloak_test.go b/management/server/idp/keycloak_test.go new file mode 100644 index 000000000..00acf81bd --- /dev/null +++ b/management/server/idp/keycloak_test.go @@ -0,0 +1,401 @@ +package idp + +import ( + "fmt" + "io" + "strings" + "testing" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewKeycloakManager(t *testing.T) { + type test struct { + name string + inputConfig KeycloakClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := KeycloakClientConfig{ + ClientID: "client_id", + ClientSecret: "client_secret", + AdminEndpoint: "https://localhost:8080/auth/admin/realms/test123", + TokenEndpoint: "https://localhost:8080/auth/realms/test123/protocol/openid-connect/token", + GrantType: "client_credentials", + } + + testCase1 := test{ + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + testCase2Config := defaultTestConfig + testCase2Config.ClientID = "" + + testCase2 := test{ + name: "Missing ClientID Configuration", + inputConfig: testCase2Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + } + + testCase5Config := defaultTestConfig + testCase5Config.GrantType = "authorization_code" + + testCase5 := test{ + name: "Wrong GrantType", + inputConfig: testCase5Config, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when wrong grant type", + } + + for _, testCase := range []test{testCase1, testCase2, testCase5} { + t.Run(testCase.name, func(t *testing.T) { + _, err := NewKeycloakManager(testCase.inputConfig, &telemetry.MockAppMetrics{}) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + }) + } +} + +type mockKeycloakCredentials struct { + jwtToken JWTToken + err error +} + +func (mc *mockKeycloakCredentials) Authenticate() (JWTToken, error) { + return mc.jwtToken, mc.err +} + +func TestKeycloakRequestJWTToken(t *testing.T) { + + type requestJWTTokenTest struct { + name string + inputCode int + inputRespBody string + helper ManagerHelper + expectedFuncExitErrDiff error + expectedToken string + } + exp := 5 + token := newTestJWT(t, exp) + + requestJWTTokenTesttCase1 := requestJWTTokenTest{ + name: "Good JWT Response", + inputCode: 200, + inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), + helper: JsonParser{}, + expectedToken: token, + } + requestJWTTokenTestCase2 := requestJWTTokenTest{ + name: "Request Bad Status Code", + inputCode: 400, + inputRespBody: "{}", + helper: JsonParser{}, + expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"), + expectedToken: "", + } + + for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} { + t.Run(testCase.name, func(t *testing.T) { + + jwtReqClient := mockHTTPClient{ + resBody: testCase.inputRespBody, + code: testCase.inputCode, + } + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + httpClient: &jwtReqClient, + helper: testCase.helper, + } + + resp, err := creds.requestJWTToken() + if err != nil { + if testCase.expectedFuncExitErrDiff != nil { + assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") + } else { + t.Fatal(err) + } + } else { + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err, "unable to read the response body") + + jwtToken := JWTToken{} + err = testCase.helper.Unmarshal(body, &jwtToken) + assert.NoError(t, err, "unable to parse the json input") + + assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same") + } + }) + } +} + +func TestKeycloakParseRequestJWTResponse(t *testing.T) { + type parseRequestJWTResponseTest struct { + name string + inputRespBody string + helper ManagerHelper + expectedToken string + expectedExpiresIn int + assertErrFunc assert.ErrorAssertionFunc + assertErrFuncMessage string + } + + exp := 100 + token := newTestJWT(t, exp) + + parseRequestJWTResponseTestCase1 := parseRequestJWTResponseTest{ + name: "Parse Good JWT Body", + inputRespBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), + helper: JsonParser{}, + expectedToken: token, + expectedExpiresIn: exp, + assertErrFunc: assert.NoError, + assertErrFuncMessage: "no error was expected", + } + parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{ + name: "Parse Bad json JWT Body", + inputRespBody: "", + helper: JsonParser{}, + expectedToken: "", + expectedExpiresIn: 0, + assertErrFunc: assert.Error, + assertErrFuncMessage: "json error was expected", + } + + for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} { + t.Run(testCase.name, func(t *testing.T) { + rawBody := io.NopCloser(strings.NewReader(testCase.inputRespBody)) + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + helper: testCase.helper, + } + jwtToken, err := creds.parseRequestJWTResponse(rawBody) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + + assert.Equalf(t, testCase.expectedToken, jwtToken.AccessToken, "two tokens should be the same") + assert.Equalf(t, testCase.expectedExpiresIn, jwtToken.ExpiresIn, "the two expire times should be the same") + }) + } +} + +func TestKeycloakJwtStillValid(t *testing.T) { + type jwtStillValidTest struct { + name string + inputTime time.Time + expectedResult bool + message string + } + + jwtStillValidTestCase1 := jwtStillValidTest{ + name: "JWT still valid", + inputTime: time.Now().Add(10 * time.Second), + expectedResult: true, + message: "should be true", + } + jwtStillValidTestCase2 := jwtStillValidTest{ + name: "JWT is invalid", + inputTime: time.Now(), + expectedResult: false, + message: "should be false", + } + + for _, testCase := range []jwtStillValidTest{jwtStillValidTestCase1, jwtStillValidTestCase2} { + t.Run(testCase.name, func(t *testing.T) { + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + } + creds.jwtToken.expiresInTime = testCase.inputTime + + assert.Equalf(t, testCase.expectedResult, creds.jwtStillValid(), testCase.message) + }) + } +} + +func TestKeycloakAuthenticate(t *testing.T) { + type authenticateTest struct { + name string + inputCode int + inputResBody string + inputExpireToken time.Time + helper ManagerHelper + expectedFuncExitErrDiff error + expectedCode int + expectedToken string + } + exp := 5 + token := newTestJWT(t, exp) + + authenticateTestCase1 := authenticateTest{ + name: "Get Cached token", + inputExpireToken: time.Now().Add(30 * time.Second), + helper: JsonParser{}, + expectedFuncExitErrDiff: nil, + expectedCode: 200, + expectedToken: "", + } + + authenticateTestCase2 := authenticateTest{ + name: "Get Good JWT Response", + inputCode: 200, + inputResBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), + helper: JsonParser{}, + expectedCode: 200, + expectedToken: token, + } + + authenticateTestCase3 := authenticateTest{ + name: "Get Bad Status Code", + inputCode: 400, + inputResBody: "{}", + helper: JsonParser{}, + expectedFuncExitErrDiff: fmt.Errorf("unable to get keycloak token, statusCode 400"), + expectedCode: 200, + expectedToken: "", + } + + for _, testCase := range []authenticateTest{authenticateTestCase1, authenticateTestCase2, authenticateTestCase3} { + t.Run(testCase.name, func(t *testing.T) { + + jwtReqClient := mockHTTPClient{ + resBody: testCase.inputResBody, + code: testCase.inputCode, + } + config := KeycloakClientConfig{} + + creds := KeycloakCredentials{ + clientConfig: config, + httpClient: &jwtReqClient, + helper: testCase.helper, + } + creds.jwtToken.expiresInTime = testCase.inputExpireToken + + _, err := creds.Authenticate() + if err != nil { + if testCase.expectedFuncExitErrDiff != nil { + assert.EqualError(t, err, testCase.expectedFuncExitErrDiff.Error(), "errors should be the same") + } else { + t.Fatal(err) + } + } + + assert.Equalf(t, testCase.expectedToken, creds.jwtToken.AccessToken, "two tokens should be the same") + }) + } +} + +func TestKeycloakUpdateUserAppMetadata(t *testing.T) { + type updateUserAppMetadataTest struct { + name string + inputReqBody string + expectedReqBody string + appMetadata AppMetadata + statusCode int + helper ManagerHelper + managerCreds ManagerCredentials + assertErrFunc assert.ErrorAssertionFunc + assertErrFuncMessage string + } + + appMetadata := AppMetadata{WTAccountID: "ok"} + + updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ + name: "Bad Authentication", + expectedReqBody: "", + appMetadata: appMetadata, + statusCode: 400, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + err: fmt.Errorf("error"), + }, + assertErrFunc: assert.Error, + assertErrFuncMessage: "should return error", + } + + updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ + name: "Bad Status Code", + expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID), + appMetadata: appMetadata, + statusCode: 400, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.Error, + assertErrFuncMessage: "should return error", + } + + updateUserAppMetadataTestCase3 := updateUserAppMetadataTest{ + name: "Bad Response Parsing", + statusCode: 400, + helper: &mockJsonParser{marshalErrorString: "error"}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.Error, + assertErrFuncMessage: "should return error", + } + + updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ + name: "Good request", + expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"false\"]}}", appMetadata.WTAccountID), + appMetadata: appMetadata, + statusCode: 204, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + invite := true + updateUserAppMetadataTestCase5 := updateUserAppMetadataTest{ + name: "Update Pending Invite", + expectedReqBody: fmt.Sprintf("{\"attributes\":{\"wt_account_id\":[\"%s\"],\"wt_pending_invite\":[\"true\"]}}", appMetadata.WTAccountID), + appMetadata: AppMetadata{ + WTAccountID: "ok", + WTPendingInvite: &invite, + }, + statusCode: 204, + helper: JsonParser{}, + managerCreds: &mockKeycloakCredentials{ + jwtToken: JWTToken{}, + }, + assertErrFunc: assert.NoError, + assertErrFuncMessage: "shouldn't return error", + } + + for _, testCase := range []updateUserAppMetadataTest{updateUserAppMetadataTestCase1, updateUserAppMetadataTestCase2, + updateUserAppMetadataTestCase3, updateUserAppMetadataTestCase4, updateUserAppMetadataTestCase5} { + t.Run(testCase.name, func(t *testing.T) { + reqClient := mockHTTPClient{ + resBody: testCase.inputReqBody, + code: testCase.statusCode, + } + + manager := &KeycloakManager{ + httpClient: &reqClient, + credentials: testCase.managerCreds, + helper: testCase.helper, + } + + err := manager.UpdateUserAppMetadata("1", testCase.appMetadata) + testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage) + + assert.Equal(t, testCase.expectedReqBody, reqClient.reqBody, "request body should match") + }) + } +}