Rotate Access token with refresh token (#280)

Add method for rotating access token with refresh tokens
This will be useful for catching expired sessions and
offboarding users

Also added functions to handle secrets. They have to be revisited
as some tests didn't run on CI as they waited some user input, like password
This commit is contained in:
Maycon Santos
2022-03-22 13:12:11 +01:00
committed by GitHub
parent 76db9afa11
commit a2fc4ec221
9 changed files with 329 additions and 46 deletions

View File

@@ -12,7 +12,10 @@ import (
)
// auth0GrantType grant type for device flow on Auth0
const auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
const (
auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
auth0RefreshGrant = "refresh_token"
)
// Auth0 client
type Auth0 struct {
@@ -34,9 +37,10 @@ type RequestDeviceCodePayload struct {
// TokenRequestPayload used for requesting the auth0 token
type TokenRequestPayload struct {
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
GrantType string `json:"grant_type"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token,omitempty"`
}
// TokenRequestResponse used for parsing Auth0 token's response
@@ -48,7 +52,7 @@ type TokenRequestResponse struct {
// Claims used when validating the access token
type Claims struct {
Audiance string `json:"aud"`
Audience string `json:"aud"`
}
// NewAuth0DeviceFlow returns an Auth0 OAuth client
@@ -127,30 +131,13 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
DeviceCode: info.DeviceCode,
ClientID: a.ClientID,
}
p, err := json.Marshal(tokenReqPayload)
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token payload failed with error: %v", err)
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", url, payload)
if err != nil {
return TokenInfo{}, fmt.Errorf("creating wait token request failed with error: %v", err)
return TokenInfo{}, fmt.Errorf("wait for token: %v", err)
}
req.Header.Add("content-type", "application/json")
res, err := a.HTTPClient.Do(req)
if err != nil {
return TokenInfo{}, fmt.Errorf("doing wiat request failed with error: %v", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return TokenInfo{}, fmt.Errorf("reading toekn body failed with error: %v", err)
}
if res.StatusCode > 499 {
if statusCode > 499 {
return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body))
}
@@ -184,6 +171,71 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
}
}
// RotateAccessToken requests a new token using an existing refresh token
func (a *Auth0) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) {
url := "https://" + a.Domain + "/oauth/token"
tokenReqPayload := TokenRequestPayload{
GrantType: auth0RefreshGrant,
ClientID: a.ClientID,
RefreshToken: refreshToken,
}
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
if err != nil {
return TokenInfo{}, fmt.Errorf("rotate access token: %v", err)
}
if statusCode != 200 {
return TokenInfo{}, fmt.Errorf("rotating token returned error: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
err = isValidAccessToken(tokenResponse.AccessToken, a.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
}
return tokenInfo, err
}
func requestToken(client HTTPClient, url string, tokenReqPayload TokenRequestPayload) ([]byte, int, error) {
p, err := json.Marshal(tokenReqPayload)
if err != nil {
return nil, 0, fmt.Errorf("parsing token payload failed with error: %v", err)
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", url, payload)
if err != nil {
return nil, 0, fmt.Errorf("creating token request failed with error: %v", err)
}
req.Header.Add("content-type", "application/json")
res, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("doing token request failed with error: %v", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, 0, fmt.Errorf("reading token body failed with error: %v", err)
}
return body, res.StatusCode, nil
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
@@ -202,7 +254,7 @@ func isValidAccessToken(token string, audience string) error {
return err
}
if claims.Audiance != audience {
if claims.Audience != audience {
return fmt.Errorf("invalid audience")
}

View File

@@ -152,7 +152,6 @@ func TestAuth0_WaitToken(t *testing.T) {
Interval: 1,
}
// payload,timeout,error 500,error conn,error response, invalid token,good req
tokenReqPayload := TokenRequestPayload{
GrantType: auth0GrantType,
DeviceCode: defaultInfo.DeviceCode,
@@ -294,3 +293,123 @@ func TestAuth0_WaitToken(t *testing.T) {
})
}
}
func TestAuth0_RotateAccessToken(t *testing.T) {
type test struct {
name string
inputResBody string
inputReqCode int
inputReqError error
inputMaxReqs int
inputInfo DeviceAuthInfo
inputAudience string
testingErrFunc require.ErrorAssertionFunc
expectedErrorMSG string
testingFunc require.ComparisonAssertionFunc
expectedOut TokenInfo
expectedMSG string
expectPayload TokenRequestPayload
}
defaultInfo := DeviceAuthInfo{
DeviceCode: "test",
ExpiresIn: 10,
Interval: 1,
}
tokenReqPayload := TokenRequestPayload{
GrantType: auth0RefreshGrant,
ClientID: "test",
RefreshToken: "refresh_test",
}
testCase1 := test{
name: "Payload Is Valid",
inputInfo: defaultInfo,
inputReqCode: 200,
testingErrFunc: require.Error,
testingFunc: require.EqualValues,
expectPayload: tokenReqPayload,
}
testCase2 := test{
name: "Exit On Network Error",
inputInfo: defaultInfo,
expectPayload: tokenReqPayload,
inputReqError: fmt.Errorf("error"),
testingErrFunc: require.Error,
expectedErrorMSG: "should return error",
testingFunc: require.EqualValues,
}
testCase3 := test{
name: "Exit On Non 200 Status Code",
inputInfo: defaultInfo,
inputReqCode: 401,
expectPayload: tokenReqPayload,
testingErrFunc: require.Error,
expectedErrorMSG: "should return error",
testingFunc: require.EqualValues,
}
audience := "test"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
var hmacSampleSecret []byte
tokenString, _ := token.SignedString(hmacSampleSecret)
testCase4 := test{
name: "Exit On Invalid Audience",
inputInfo: defaultInfo,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
inputReqCode: 200,
inputAudience: "super test",
testingErrFunc: require.Error,
testingFunc: require.EqualValues,
expectPayload: tokenReqPayload,
}
testCase5 := test{
name: "Received Token Info",
inputInfo: defaultInfo,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
inputReqCode: 200,
inputAudience: audience,
testingErrFunc: require.NoError,
testingFunc: require.EqualValues,
expectPayload: tokenReqPayload,
expectedOut: TokenInfo{AccessToken: tokenString},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) {
httpClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputReqCode,
err: testCase.inputReqError,
MaxReqs: testCase.inputMaxReqs,
}
auth0 := Auth0{
Audience: testCase.inputAudience,
ClientID: testCase.expectPayload.ClientID,
Domain: "test.auth0.com",
HTTPClient: &httpClient,
}
tokenInfo, err := auth0.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken)
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
var payload []byte
var emptyPayload TokenRequestPayload
if testCase.expectPayload != emptyPayload {
payload, _ = json.Marshal(testCase.expectPayload)
}
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
})
}
}

View File

@@ -32,5 +32,6 @@ type TokenInfo struct {
// Client is a OAuth client interface for various idp providers
type Client interface {
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
}

View File

@@ -0,0 +1,66 @@
package oauth
import (
"fmt"
"github.com/99designs/keyring"
)
// ServiceName default service name for saving the secret
const ServiceName = "Wiretrustee"
func newSecretAPI() (keyring.Keyring, error) {
return keyring.Open(keyring.Config{
ServiceName: ServiceName,
//KeychainName: ServiceName,
})
}
// SetSecret stores the secret in the system's available backend
func SetSecret(key string, value string) error {
storeAPI, err := newSecretAPI()
if err != nil {
return fmt.Errorf("failed to create secret API for setting a secret, error: %v", err)
}
item := keyring.Item{
Key: key,
Data: []byte(value),
}
err = storeAPI.Set(item)
if err != nil {
return fmt.Errorf("failed to set the secret, error: %v", err)
}
return nil
}
// GetSecret retrieves a secret from the system's available backend
func GetSecret(key string) (string, error) {
storeAPI, err := newSecretAPI()
if err != nil {
return "", fmt.Errorf("failed to create secret API for getting a secret, error: %v", err)
}
item, err := storeAPI.Get(key)
if err != nil {
return "", fmt.Errorf("failed to get secret, error: %v", err)
}
return string(item.Data), nil
}
// DeleteSecret deletes a secret from the system's available backend
func DeleteSecret(key string) error {
storeAPI, err := newSecretAPI()
if err != nil {
return fmt.Errorf("failed to create secret API for deleting a secret, error: %v", err)
}
err = storeAPI.Remove(key)
if err != nil {
return fmt.Errorf("failed to delete secret, error: %v", err)
}
return nil
}

View File

@@ -0,0 +1,27 @@
package oauth
import (
"github.com/stretchr/testify/require"
"os"
"testing"
)
func TestSecret(t *testing.T) {
// this test is not ready to run as part of our ci/cd
// todo fix testing
if os.Getenv("CI") == "true" {
t.Skip("skipping testing in github actions")
}
key := "testing"
value := "1234"
err := SetSecret(key, value)
require.NoError(t, err, "should set secret")
v, err := GetSecret(key)
require.NoError(t, err, "should retrieve secret")
require.Equal(t, value, v, "values should match")
err = DeleteSecret(key)
require.NoError(t, err, "should delete secret")
}