mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
Rotate Access token with refresh token (#280)
Add method for rotating access token with refresh tokens This will be useful for catching expired sessions and offboarding users Also added functions to handle secrets. They have to be revisited as some tests didn't run on CI as they waited some user input, like password
This commit is contained in:
@@ -12,7 +12,10 @@ import (
|
||||
)
|
||||
|
||||
// auth0GrantType grant type for device flow on Auth0
|
||||
const auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
const (
|
||||
auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
auth0RefreshGrant = "refresh_token"
|
||||
)
|
||||
|
||||
// Auth0 client
|
||||
type Auth0 struct {
|
||||
@@ -34,9 +37,10 @@ type RequestDeviceCodePayload struct {
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
type TokenRequestPayload struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code"`
|
||||
ClientID string `json:"client_id"`
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// TokenRequestResponse used for parsing Auth0 token's response
|
||||
@@ -48,7 +52,7 @@ type TokenRequestResponse struct {
|
||||
|
||||
// Claims used when validating the access token
|
||||
type Claims struct {
|
||||
Audiance string `json:"aud"`
|
||||
Audience string `json:"aud"`
|
||||
}
|
||||
|
||||
// NewAuth0DeviceFlow returns an Auth0 OAuth client
|
||||
@@ -127,30 +131,13 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
DeviceCode: info.DeviceCode,
|
||||
ClientID: a.ClientID,
|
||||
}
|
||||
p, err := json.Marshal(tokenReqPayload)
|
||||
|
||||
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token payload failed with error: %v", err)
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
req, err := http.NewRequest("POST", url, payload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("creating wait token request failed with error: %v", err)
|
||||
return TokenInfo{}, fmt.Errorf("wait for token: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := a.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("doing wiat request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("reading toekn body failed with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode > 499 {
|
||||
if statusCode > 499 {
|
||||
return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body))
|
||||
}
|
||||
|
||||
@@ -184,6 +171,71 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// RotateAccessToken requests a new token using an existing refresh token
|
||||
func (a *Auth0) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) {
|
||||
url := "https://" + a.Domain + "/oauth/token"
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: auth0RefreshGrant,
|
||||
ClientID: a.ClientID,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("rotate access token: %v", err)
|
||||
}
|
||||
|
||||
if statusCode != 200 {
|
||||
return TokenInfo{}, fmt.Errorf("rotating token returned error: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenResponse.AccessToken, a.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
}
|
||||
return tokenInfo, err
|
||||
}
|
||||
|
||||
func requestToken(client HTTPClient, url string, tokenReqPayload TokenRequestPayload) ([]byte, int, error) {
|
||||
p, err := json.Marshal(tokenReqPayload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("parsing token payload failed with error: %v", err)
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
req, err := http.NewRequest("POST", url, payload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("creating token request failed with error: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("doing token request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("reading token body failed with error: %v", err)
|
||||
}
|
||||
return body, res.StatusCode, nil
|
||||
}
|
||||
|
||||
// isValidAccessToken is a simple validation of the access token
|
||||
func isValidAccessToken(token string, audience string) error {
|
||||
if token == "" {
|
||||
@@ -202,7 +254,7 @@ func isValidAccessToken(token string, audience string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if claims.Audiance != audience {
|
||||
if claims.Audience != audience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
|
||||
|
||||
@@ -152,7 +152,6 @@ func TestAuth0_WaitToken(t *testing.T) {
|
||||
Interval: 1,
|
||||
}
|
||||
|
||||
// payload,timeout,error 500,error conn,error response, invalid token,good req
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: auth0GrantType,
|
||||
DeviceCode: defaultInfo.DeviceCode,
|
||||
@@ -294,3 +293,123 @@ func TestAuth0_WaitToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth0_RotateAccessToken(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputResBody string
|
||||
inputReqCode int
|
||||
inputReqError error
|
||||
inputMaxReqs int
|
||||
inputInfo DeviceAuthInfo
|
||||
inputAudience string
|
||||
testingErrFunc require.ErrorAssertionFunc
|
||||
expectedErrorMSG string
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedOut TokenInfo
|
||||
expectedMSG string
|
||||
expectPayload TokenRequestPayload
|
||||
}
|
||||
|
||||
defaultInfo := DeviceAuthInfo{
|
||||
DeviceCode: "test",
|
||||
ExpiresIn: 10,
|
||||
Interval: 1,
|
||||
}
|
||||
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: auth0RefreshGrant,
|
||||
ClientID: "test",
|
||||
RefreshToken: "refresh_test",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Payload Is Valid",
|
||||
inputInfo: defaultInfo,
|
||||
inputReqCode: 200,
|
||||
testingErrFunc: require.Error,
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: tokenReqPayload,
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Exit On Network Error",
|
||||
inputInfo: defaultInfo,
|
||||
expectPayload: tokenReqPayload,
|
||||
inputReqError: fmt.Errorf("error"),
|
||||
testingErrFunc: require.Error,
|
||||
expectedErrorMSG: "should return error",
|
||||
testingFunc: require.EqualValues,
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Exit On Non 200 Status Code",
|
||||
inputInfo: defaultInfo,
|
||||
inputReqCode: 401,
|
||||
expectPayload: tokenReqPayload,
|
||||
testingErrFunc: require.Error,
|
||||
expectedErrorMSG: "should return error",
|
||||
testingFunc: require.EqualValues,
|
||||
}
|
||||
|
||||
audience := "test"
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
|
||||
var hmacSampleSecret []byte
|
||||
tokenString, _ := token.SignedString(hmacSampleSecret)
|
||||
|
||||
testCase4 := test{
|
||||
name: "Exit On Invalid Audience",
|
||||
inputInfo: defaultInfo,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
||||
inputReqCode: 200,
|
||||
inputAudience: "super test",
|
||||
testingErrFunc: require.Error,
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: tokenReqPayload,
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Received Token Info",
|
||||
inputInfo: defaultInfo,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
||||
inputReqCode: 200,
|
||||
inputAudience: audience,
|
||||
testingErrFunc: require.NoError,
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: tokenReqPayload,
|
||||
expectedOut: TokenInfo{AccessToken: tokenString},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
httpClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputReqCode,
|
||||
err: testCase.inputReqError,
|
||||
MaxReqs: testCase.inputMaxReqs,
|
||||
}
|
||||
|
||||
auth0 := Auth0{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: testCase.expectPayload.ClientID,
|
||||
Domain: "test.auth0.com",
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
tokenInfo, err := auth0.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken)
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
var payload []byte
|
||||
var emptyPayload TokenRequestPayload
|
||||
if testCase.expectPayload != emptyPayload {
|
||||
payload, _ = json.Marshal(testCase.expectPayload)
|
||||
}
|
||||
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
||||
|
||||
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,5 +32,6 @@ type TokenInfo struct {
|
||||
// Client is a OAuth client interface for various idp providers
|
||||
type Client interface {
|
||||
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
|
||||
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
||||
}
|
||||
|
||||
66
client/internal/oauth/secret.go
Normal file
66
client/internal/oauth/secret.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/99designs/keyring"
|
||||
)
|
||||
|
||||
// ServiceName default service name for saving the secret
|
||||
const ServiceName = "Wiretrustee"
|
||||
|
||||
func newSecretAPI() (keyring.Keyring, error) {
|
||||
return keyring.Open(keyring.Config{
|
||||
ServiceName: ServiceName,
|
||||
//KeychainName: ServiceName,
|
||||
})
|
||||
}
|
||||
|
||||
// SetSecret stores the secret in the system's available backend
|
||||
func SetSecret(key string, value string) error {
|
||||
storeAPI, err := newSecretAPI()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create secret API for setting a secret, error: %v", err)
|
||||
}
|
||||
|
||||
item := keyring.Item{
|
||||
Key: key,
|
||||
Data: []byte(value),
|
||||
}
|
||||
|
||||
err = storeAPI.Set(item)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set the secret, error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSecret retrieves a secret from the system's available backend
|
||||
func GetSecret(key string) (string, error) {
|
||||
storeAPI, err := newSecretAPI()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create secret API for getting a secret, error: %v", err)
|
||||
}
|
||||
|
||||
item, err := storeAPI.Get(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get secret, error: %v", err)
|
||||
}
|
||||
|
||||
return string(item.Data), nil
|
||||
}
|
||||
|
||||
// DeleteSecret deletes a secret from the system's available backend
|
||||
func DeleteSecret(key string) error {
|
||||
storeAPI, err := newSecretAPI()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create secret API for deleting a secret, error: %v", err)
|
||||
}
|
||||
|
||||
err = storeAPI.Remove(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete secret, error: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
27
client/internal/oauth/secret_test.go
Normal file
27
client/internal/oauth/secret_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecret(t *testing.T) {
|
||||
// this test is not ready to run as part of our ci/cd
|
||||
// todo fix testing
|
||||
if os.Getenv("CI") == "true" {
|
||||
t.Skip("skipping testing in github actions")
|
||||
}
|
||||
|
||||
key := "testing"
|
||||
value := "1234"
|
||||
err := SetSecret(key, value)
|
||||
require.NoError(t, err, "should set secret")
|
||||
|
||||
v, err := GetSecret(key)
|
||||
require.NoError(t, err, "should retrieve secret")
|
||||
require.Equal(t, value, v, "values should match")
|
||||
|
||||
err = DeleteSecret(key)
|
||||
require.NoError(t, err, "should delete secret")
|
||||
}
|
||||
Reference in New Issue
Block a user