From bd61be24bebc93e3ee214fea515ea54561793aa2 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 20 Mar 2022 08:29:18 +0100 Subject: [PATCH] Add OAuth Package and Auth0 Client (#273) Adding package for retrieving an access token with a device login flow For now, we got Auth0 as a client but the Interface Client is ready --- client/internal/oauth/auth0.go | 210 ++++++++++++++++++++ client/internal/oauth/auth0_test.go | 296 ++++++++++++++++++++++++++++ client/internal/oauth/oauth.go | 36 ++++ 3 files changed, 542 insertions(+) create mode 100644 client/internal/oauth/auth0.go create mode 100644 client/internal/oauth/auth0_test.go create mode 100644 client/internal/oauth/oauth.go diff --git a/client/internal/oauth/auth0.go b/client/internal/oauth/auth0.go new file mode 100644 index 000000000..3370994bd --- /dev/null +++ b/client/internal/oauth/auth0.go @@ -0,0 +1,210 @@ +package oauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" + "time" +) + +// auth0GrantType grant type for device flow on Auth0 +const auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code" + +// Auth0 client +type Auth0 struct { + // Auth0 API Audience for validation + Audience string + // Auth0 Native application client id + ClientID string + // Auth0 domain + Domain string + + HTTPClient HTTPClient +} + +// RequestDeviceCodePayload used for request device code payload for auth0 +type RequestDeviceCodePayload struct { + Audience string `json:"audience"` + ClientID string `json:"client_id"` +} + +// TokenRequestPayload used for requesting the auth0 token +type TokenRequestPayload struct { + GrantType string `json:"grant_type"` + DeviceCode string `json:"device_code"` + ClientID string `json:"client_id"` +} + +// TokenRequestResponse used for parsing Auth0 token's response +type TokenRequestResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + TokenInfo +} + +// Claims used when validating the access token +type Claims struct { + Audiance string `json:"aud"` +} + +// NewAuth0DeviceFlow returns an Auth0 OAuth client +func NewAuth0DeviceFlow(audience string, clientID string, domain string) *Auth0 { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + + return &Auth0{ + Audience: audience, + ClientID: clientID, + Domain: domain, + HTTPClient: httpClient, + } +} + +// RequestDeviceCode requests a device code login flow information from Auth0 +func (a *Auth0) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) { + url := "https://" + a.Domain + "/oauth/device/code" + codePayload := RequestDeviceCodePayload{ + Audience: a.Audience, + ClientID: a.ClientID, + } + p, err := json.Marshal(codePayload) + if err != nil { + return DeviceAuthInfo{}, fmt.Errorf("parsing payload failed with error: %v", err) + } + payload := strings.NewReader(string(p)) + req, err := http.NewRequest("POST", url, payload) + if err != nil { + return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err) + } + + req.Header.Add("content-type", "application/json") + + res, err := a.HTTPClient.Do(req) + if err != nil { + return DeviceAuthInfo{}, fmt.Errorf("doing request failed with error: %v", err) + } + + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err) + } + + if res.StatusCode != 200 { + return DeviceAuthInfo{}, fmt.Errorf("request device code returned status %d error: %s", res.StatusCode, string(body)) + } + + deviceCode := DeviceAuthInfo{} + err = json.Unmarshal(body, &deviceCode) + if err != nil { + return DeviceAuthInfo{}, fmt.Errorf("unmarshaling response failed with error: %v", err) + } + + return deviceCode, err +} + +// WaitToken waits user's login and authorize the app. Once the user's authorize +// it retrieves the access token from Auth0's endpoint and validates it before returning +func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) { + ticker := time.NewTicker(time.Duration(info.Interval) * time.Second) + for { + select { + case <-ctx.Done(): + return TokenInfo{}, ctx.Err() + case <-ticker.C: + url := "https://" + a.Domain + "/oauth/token" + tokenReqPayload := TokenRequestPayload{ + GrantType: auth0GrantType, + DeviceCode: info.DeviceCode, + ClientID: a.ClientID, + } + p, err := json.Marshal(tokenReqPayload) + if err != nil { + return TokenInfo{}, fmt.Errorf("parsing token payload failed with error: %v", err) + } + payload := strings.NewReader(string(p)) + req, err := http.NewRequest("POST", url, payload) + if err != nil { + return TokenInfo{}, fmt.Errorf("creating wait token request failed with error: %v", err) + } + + req.Header.Add("content-type", "application/json") + + res, err := a.HTTPClient.Do(req) + if err != nil { + return TokenInfo{}, fmt.Errorf("doing wiat request failed with error: %v", err) + } + + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return TokenInfo{}, fmt.Errorf("reading toekn body failed with error: %v", err) + } + + if res.StatusCode > 499 { + return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body)) + } + + tokenResponse := TokenRequestResponse{} + err = json.Unmarshal(body, &tokenResponse) + if err != nil { + return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err) + } + + if tokenResponse.Error != "" { + if tokenResponse.Error == "authorization_pending" { + continue + } + return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription) + } + + err = isValidAccessToken(tokenResponse.AccessToken, a.Audience) + if err != nil { + return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + } + + tokenInfo := TokenInfo{ + AccessToken: tokenResponse.AccessToken, + TokenType: tokenResponse.TokenType, + RefreshToken: tokenResponse.RefreshToken, + IDToken: tokenResponse.IDToken, + ExpiresIn: tokenResponse.ExpiresIn, + } + return tokenInfo, err + } + } +} + +// isValidAccessToken is a simple validation of the access token +func isValidAccessToken(token string, audience string) error { + if token == "" { + return fmt.Errorf("token received is empty") + } + + encodedClaims := strings.Split(token, ".")[1] + claimsString, err := base64.RawURLEncoding.DecodeString(encodedClaims) + if err != nil { + return err + } + + claims := Claims{} + err = json.Unmarshal(claimsString, &claims) + if err != nil { + return err + } + + if claims.Audiance != audience { + return fmt.Errorf("invalid audience") + } + + return nil +} diff --git a/client/internal/oauth/auth0_test.go b/client/internal/oauth/auth0_test.go new file mode 100644 index 000000000..2246b166e --- /dev/null +++ b/client/internal/oauth/auth0_test.go @@ -0,0 +1,296 @@ +package oauth + +import ( + "context" + "encoding/json" + "fmt" + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/require" + "io/ioutil" + "net/http" + "strings" + "testing" + "time" +) + +type mockHTTPClient struct { + code int + resBody string + reqBody string + MaxReqs int + count int + countResBody string + err error +} + +func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + body, err := ioutil.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } + + if c.MaxReqs > c.count { + c.count++ + return &http.Response{ + StatusCode: c.code, + Body: ioutil.NopCloser(strings.NewReader(c.countResBody)), + }, c.err + } + + return &http.Response{ + StatusCode: c.code, + Body: ioutil.NopCloser(strings.NewReader(c.resBody)), + }, c.err +} + +func TestAuth0_RequestDeviceCode(t *testing.T) { + type test struct { + name string + inputResBody string + inputReqCode int + inputReqError error + testingErrFunc require.ErrorAssertionFunc + expectedErrorMSG string + testingFunc require.ComparisonAssertionFunc + expectedOut DeviceAuthInfo + expectedMSG string + expectPayload RequestDeviceCodePayload + } + + testCase1 := test{ + name: "Payload Is Valid", + expectPayload: RequestDeviceCodePayload{ + Audience: "ok", + ClientID: "bla", + }, + inputReqCode: 200, + testingErrFunc: require.Error, + testingFunc: require.EqualValues, + } + + testCase2 := test{ + name: "Exit On Network Error", + inputReqError: fmt.Errorf("error"), + testingErrFunc: require.Error, + expectedErrorMSG: "should return error", + testingFunc: require.EqualValues, + } + + testCase3 := test{ + name: "Exit On Exit Code", + inputReqCode: 400, + testingErrFunc: require.Error, + expectedErrorMSG: "should return error", + testingFunc: require.EqualValues, + } + testCase4Out := DeviceAuthInfo{ExpiresIn: 10} + testCase4 := test{ + name: "Got Device Code", + inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn), + expectPayload: RequestDeviceCodePayload{ + Audience: "ok", + ClientID: "bla", + }, + inputReqCode: 200, + testingErrFunc: require.NoError, + testingFunc: require.EqualValues, + expectedOut: testCase4Out, + expectedMSG: "out should match", + } + + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} { + t.Run(testCase.name, func(t *testing.T) { + + httpClient := mockHTTPClient{ + resBody: testCase.inputResBody, + code: testCase.inputReqCode, + err: testCase.inputReqError, + } + + auth0 := Auth0{ + Audience: testCase.expectPayload.Audience, + ClientID: testCase.expectPayload.ClientID, + Domain: "test.auth0.com", + HTTPClient: &httpClient, + } + + authInfo, err := auth0.RequestDeviceCode(context.TODO()) + testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) + + payload, _ := json.Marshal(testCase.expectPayload) + + require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match") + + testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG) + + }) + } +} + +func TestAuth0_WaitToken(t *testing.T) { + type test struct { + name string + inputResBody string + inputReqCode int + inputReqError error + inputMaxReqs int + inputCountResBody string + inputTimeout time.Duration + inputInfo DeviceAuthInfo + inputAudience string + testingErrFunc require.ErrorAssertionFunc + expectedErrorMSG string + testingFunc require.ComparisonAssertionFunc + expectedOut TokenInfo + expectedMSG string + expectPayload TokenRequestPayload + } + + defaultInfo := DeviceAuthInfo{ + DeviceCode: "test", + ExpiresIn: 10, + Interval: 1, + } + + // payload,timeout,error 500,error conn,error response, invalid token,good req + tokenReqPayload := TokenRequestPayload{ + GrantType: auth0GrantType, + DeviceCode: defaultInfo.DeviceCode, + ClientID: "test", + } + + testCase1 := test{ + name: "Payload Is Valid", + inputInfo: defaultInfo, + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + inputReqCode: 200, + testingErrFunc: require.Error, + testingFunc: require.EqualValues, + expectPayload: tokenReqPayload, + } + + testCase2 := test{ + name: "Exit On Network Error", + inputInfo: defaultInfo, + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + expectPayload: tokenReqPayload, + inputReqError: fmt.Errorf("error"), + testingErrFunc: require.Error, + expectedErrorMSG: "should return error", + testingFunc: require.EqualValues, + } + + testCase3 := test{ + name: "Exit On 4XX When Not Pending", + inputInfo: defaultInfo, + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + inputReqCode: 400, + expectPayload: tokenReqPayload, + testingErrFunc: require.Error, + expectedErrorMSG: "should return error", + testingFunc: require.EqualValues, + } + + testCase4 := test{ + name: "Exit On Exit Code 5XX", + inputInfo: defaultInfo, + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + inputReqCode: 500, + expectPayload: tokenReqPayload, + testingErrFunc: require.Error, + expectedErrorMSG: "should return error", + testingFunc: require.EqualValues, + } + + testCase5 := test{ + name: "Exit On Content Timeout", + inputInfo: defaultInfo, + inputTimeout: 0 * time.Second, + testingErrFunc: require.Error, + expectedErrorMSG: "should return error", + testingFunc: require.EqualValues, + } + + audience := "test" + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience}) + var hmacSampleSecret []byte + tokenString, _ := token.SignedString(hmacSampleSecret) + + testCase6 := test{ + name: "Exit On Invalid Audience", + inputInfo: defaultInfo, + inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + inputReqCode: 200, + inputAudience: "super test", + testingErrFunc: require.Error, + testingFunc: require.EqualValues, + expectPayload: tokenReqPayload, + } + + testCase7 := test{ + name: "Received Token Info", + inputInfo: defaultInfo, + inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + inputReqCode: 200, + inputAudience: audience, + testingErrFunc: require.NoError, + testingFunc: require.EqualValues, + expectPayload: tokenReqPayload, + expectedOut: TokenInfo{AccessToken: tokenString}, + } + + testCase8 := test{ + name: "Received Token Info after Multiple tries", + inputInfo: defaultInfo, + inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString), + inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second, + inputMaxReqs: 2, + inputCountResBody: "{\"error\":\"authorization_pending\"}", + inputReqCode: 200, + inputAudience: audience, + testingErrFunc: require.NoError, + testingFunc: require.EqualValues, + expectPayload: tokenReqPayload, + expectedOut: TokenInfo{AccessToken: tokenString}, + } + + for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7, testCase8} { + t.Run(testCase.name, func(t *testing.T) { + + httpClient := mockHTTPClient{ + resBody: testCase.inputResBody, + code: testCase.inputReqCode, + err: testCase.inputReqError, + MaxReqs: testCase.inputMaxReqs, + countResBody: testCase.inputCountResBody, + } + + auth0 := Auth0{ + Audience: testCase.inputAudience, + ClientID: testCase.expectPayload.ClientID, + Domain: "test.auth0.com", + HTTPClient: &httpClient, + } + + ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) + defer cancel() + tokenInfo, err := auth0.WaitToken(ctx, testCase.inputInfo) + testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) + + var payload []byte + var emptyPayload TokenRequestPayload + if testCase.expectPayload != emptyPayload { + payload, _ = json.Marshal(testCase.expectPayload) + } + require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match") + + testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG) + + require.GreaterOrEqualf(t, testCase.inputMaxReqs, httpClient.count, "should run %d times", testCase.inputMaxReqs) + + }) + } +} diff --git a/client/internal/oauth/oauth.go b/client/internal/oauth/oauth.go new file mode 100644 index 000000000..07e6005f7 --- /dev/null +++ b/client/internal/oauth/oauth.go @@ -0,0 +1,36 @@ +package oauth + +import ( + "context" + "net/http" +) + +// HTTPClient http client interface for API calls +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// DeviceAuthInfo holds information for the OAuth device login flow +type DeviceAuthInfo struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +// TokenInfo holds information of issued access token +type TokenInfo struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` +} + +// Client is a OAuth client interface for various idp providers +type Client interface { + RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) + WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) +}