Rotate Access token with refresh token (#280)

Add method for rotating access token with refresh tokens
This will be useful for catching expired sessions and
offboarding users

Also added functions to handle secrets. They have to be revisited
as some tests didn't run on CI as they waited some user input, like password
This commit is contained in:
Maycon Santos
2022-03-22 13:12:11 +01:00
committed by GitHub
parent 76db9afa11
commit a2fc4ec221
9 changed files with 329 additions and 46 deletions

View File

@@ -12,7 +12,10 @@ import (
)
// auth0GrantType grant type for device flow on Auth0
const auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
const (
auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
auth0RefreshGrant = "refresh_token"
)
// Auth0 client
type Auth0 struct {
@@ -34,9 +37,10 @@ type RequestDeviceCodePayload struct {
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Auth0 token's response
@@ -48,7 +52,7 @@ type TokenRequestResponse struct {
// Claims used when validating the access token
type Claims struct {
Audiance string `json:"aud"`
Audience string `json:"aud"`
}
// NewAuth0DeviceFlow returns an Auth0 OAuth client
@@ -127,30 +131,13 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
DeviceCode: info.DeviceCode,
ClientID: a.ClientID,
}
p, err := json.Marshal(tokenReqPayload)
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token payload failed with error: %v", err)
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", url, payload)
if err != nil {
return TokenInfo{}, fmt.Errorf("creating wait token request failed with error: %v", err)
return TokenInfo{}, fmt.Errorf("wait for token: %v", err)
}
req.Header.Add("content-type", "application/json")
res, err := a.HTTPClient.Do(req)
if err != nil {
return TokenInfo{}, fmt.Errorf("doing wiat request failed with error: %v", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return TokenInfo{}, fmt.Errorf("reading toekn body failed with error: %v", err)
}
if res.StatusCode > 499 {
if statusCode > 499 {
return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body))
}
@@ -184,6 +171,71 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
}
}
// RotateAccessToken requests a new token using an existing refresh token
func (a *Auth0) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) {
url := "https://" + a.Domain + "/oauth/token"
tokenReqPayload := TokenRequestPayload{
GrantType: auth0RefreshGrant,
ClientID: a.ClientID,
RefreshToken: refreshToken,
}
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
if err != nil {
return TokenInfo{}, fmt.Errorf("rotate access token: %v", err)
}
if statusCode != 200 {
return TokenInfo{}, fmt.Errorf("rotating token returned error: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
err = isValidAccessToken(tokenResponse.AccessToken, a.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
}
return tokenInfo, err
}
func requestToken(client HTTPClient, url string, tokenReqPayload TokenRequestPayload) ([]byte, int, error) {
p, err := json.Marshal(tokenReqPayload)
if err != nil {
return nil, 0, fmt.Errorf("parsing token payload failed with error: %v", err)
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", url, payload)
if err != nil {
return nil, 0, fmt.Errorf("creating token request failed with error: %v", err)
}
req.Header.Add("content-type", "application/json")
res, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("doing token request failed with error: %v", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, 0, fmt.Errorf("reading token body failed with error: %v", err)
}
return body, res.StatusCode, nil
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
@@ -202,7 +254,7 @@ func isValidAccessToken(token string, audience string) error {
return err
}
if claims.Audiance != audience {
if claims.Audience != audience {
return fmt.Errorf("invalid audience")
}