mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-29 21:56:40 +00:00
Rotate Access token with refresh token (#280)
Add method for rotating access token with refresh tokens This will be useful for catching expired sessions and offboarding users Also added functions to handle secrets. They have to be revisited as some tests didn't run on CI as they waited some user input, like password
This commit is contained in:
@@ -12,7 +12,10 @@ import (
|
||||
)
|
||||
|
||||
// auth0GrantType grant type for device flow on Auth0
|
||||
const auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
const (
|
||||
auth0GrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
auth0RefreshGrant = "refresh_token"
|
||||
)
|
||||
|
||||
// Auth0 client
|
||||
type Auth0 struct {
|
||||
@@ -34,9 +37,10 @@ type RequestDeviceCodePayload struct {
|
||||
|
||||
// TokenRequestPayload used for requesting the auth0 token
|
||||
type TokenRequestPayload struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code"`
|
||||
ClientID string `json:"client_id"`
|
||||
GrantType string `json:"grant_type"`
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// TokenRequestResponse used for parsing Auth0 token's response
|
||||
@@ -48,7 +52,7 @@ type TokenRequestResponse struct {
|
||||
|
||||
// Claims used when validating the access token
|
||||
type Claims struct {
|
||||
Audiance string `json:"aud"`
|
||||
Audience string `json:"aud"`
|
||||
}
|
||||
|
||||
// NewAuth0DeviceFlow returns an Auth0 OAuth client
|
||||
@@ -127,30 +131,13 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
DeviceCode: info.DeviceCode,
|
||||
ClientID: a.ClientID,
|
||||
}
|
||||
p, err := json.Marshal(tokenReqPayload)
|
||||
|
||||
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token payload failed with error: %v", err)
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
req, err := http.NewRequest("POST", url, payload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("creating wait token request failed with error: %v", err)
|
||||
return TokenInfo{}, fmt.Errorf("wait for token: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := a.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("doing wiat request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("reading toekn body failed with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode > 499 {
|
||||
if statusCode > 499 {
|
||||
return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body))
|
||||
}
|
||||
|
||||
@@ -184,6 +171,71 @@ func (a *Auth0) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// RotateAccessToken requests a new token using an existing refresh token
|
||||
func (a *Auth0) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) {
|
||||
url := "https://" + a.Domain + "/oauth/token"
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: auth0RefreshGrant,
|
||||
ClientID: a.ClientID,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
body, statusCode, err := requestToken(a.HTTPClient, url, tokenReqPayload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("rotate access token: %v", err)
|
||||
}
|
||||
|
||||
if statusCode != 200 {
|
||||
return TokenInfo{}, fmt.Errorf("rotating token returned error: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenResponse.AccessToken, a.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
}
|
||||
return tokenInfo, err
|
||||
}
|
||||
|
||||
func requestToken(client HTTPClient, url string, tokenReqPayload TokenRequestPayload) ([]byte, int, error) {
|
||||
p, err := json.Marshal(tokenReqPayload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("parsing token payload failed with error: %v", err)
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
req, err := http.NewRequest("POST", url, payload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("creating token request failed with error: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("doing token request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("reading token body failed with error: %v", err)
|
||||
}
|
||||
return body, res.StatusCode, nil
|
||||
}
|
||||
|
||||
// isValidAccessToken is a simple validation of the access token
|
||||
func isValidAccessToken(token string, audience string) error {
|
||||
if token == "" {
|
||||
@@ -202,7 +254,7 @@ func isValidAccessToken(token string, audience string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if claims.Audiance != audience {
|
||||
if claims.Audience != audience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user