Support Generic OAuth 2.0 Device Authorization Grant (#433)

Support Generic OAuth 2.0 Device Authorization Grant
as per RFC specification https://www.rfc-editor.org/rfc/rfc8628.
The previous version supported only Auth0 as an IDP backend.
This implementation enables the Interactive SSO Login feature 
for any IDP compatible with the specification, e.g., Keycloak.
This commit is contained in:
Misha Bragin
2022-08-23 15:46:12 +02:00
committed by GitHub
parent 47add9a9c3
commit 3def84b111
11 changed files with 309 additions and 322 deletions

View File

@@ -169,7 +169,8 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
hostedClient := internal.NewHostedDeviceFlow(
providerConfig.ProviderConfig.Audience,
providerConfig.ProviderConfig.ClientID,
providerConfig.ProviderConfig.Domain,
providerConfig.ProviderConfig.TokenEndpoint,
providerConfig.ProviderConfig.DeviceAuthEndpoint,
)
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())

View File

@@ -197,9 +197,14 @@ type ProviderConfig struct {
// ClientSecret An IDP application client secret
ClientSecret string
// Domain An IDP API domain
// Deprecated. Use OIDCConfigEndpoint instead
Domain string
// Audience An Audience for to authorization validation
Audience string
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
TokenEndpoint string
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
DeviceAuthEndpoint string
}
func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (DeviceAuthorizationFlow, error) {
@@ -250,10 +255,12 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
Provider: protoDeviceAuthorizationFlow.Provider.String(),
ProviderConfig: ProviderConfig{
Audience: protoDeviceAuthorizationFlow.ProviderConfig.Audience,
ClientID: protoDeviceAuthorizationFlow.ProviderConfig.ClientID,
ClientSecret: protoDeviceAuthorizationFlow.ProviderConfig.ClientSecret,
Domain: protoDeviceAuthorizationFlow.ProviderConfig.Domain,
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
},
}, nil
}

View File

@@ -5,8 +5,10 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"reflect"
"strings"
"time"
)
@@ -14,7 +16,6 @@ import (
// OAuthClient is a OAuth client interface for various idp providers
type OAuthClient interface {
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
GetClientID(ctx context.Context) string
}
@@ -55,8 +56,10 @@ type Hosted struct {
Audience string
// Hosted Native application client id
ClientID string
// Hosted domain
Domain string
// TokenEndpoint to request access token
TokenEndpoint string
// DeviceAuthEndpoint to request device authorization code
DeviceAuthEndpoint string
HTTPClient HTTPClient
}
@@ -84,11 +87,11 @@ type TokenRequestResponse struct {
// Claims used when validating the access token
type Claims struct {
Audience string `json:"aud"`
Audience interface{} `json:"aud"`
}
// NewHostedDeviceFlow returns an Hosted OAuth client
func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hosted {
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
@@ -98,10 +101,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste
}
return &Hosted{
Audience: audience,
ClientID: clientID,
Domain: domain,
HTTPClient: httpClient,
Audience: audience,
ClientID: clientID,
TokenEndpoint: tokenEndpoint,
HTTPClient: httpClient,
DeviceAuthEndpoint: deviceAuthEndpoint,
}
}
@@ -112,22 +116,15 @@ func (h *Hosted) GetClientID(ctx context.Context) string {
// RequestDeviceCode requests a device code login flow information from Hosted
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
url := "https://" + h.Domain + "/oauth/device/code"
codePayload := RequestDeviceCodePayload{
Audience: h.Audience,
ClientID: h.ClientID,
}
p, err := json.Marshal(codePayload)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("parsing payload failed with error: %v", err)
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", url, payload)
form := url.Values{}
form.Add("client_id", h.ClientID)
form.Add("audience", h.Audience)
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
strings.NewReader(form.Encode()))
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
}
req.Header.Add("content-type", "application/json")
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := h.HTTPClient.Do(req)
if err != nil {
@@ -135,7 +132,7 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
body, err := io.ReadAll(res.Body)
if err != nil {
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
}
@@ -153,6 +150,48 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
return deviceCode, err
}
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", h.ClientID)
form.Add("grant_type", HostedGrantType)
form.Add("device_code", info.DeviceCode)
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
res, err := h.HTTPClient.Do(req)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
}
defer func() {
err := res.Body.Close()
if err != nil {
return
}
}()
body, err := io.ReadAll(res.Body)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
}
if res.StatusCode > 499 {
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
return tokenResponse, nil
}
// WaitToken waits user's login and authorize the app. Once the user's authorize
// it retrieves the access token from Hosted's endpoint and validates it before returning
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
@@ -163,24 +202,8 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
case <-ctx.Done():
return TokenInfo{}, ctx.Err()
case <-ticker.C:
url := "https://" + h.Domain + "/oauth/token"
tokenReqPayload := TokenRequestPayload{
GrantType: HostedGrantType,
DeviceCode: info.DeviceCode,
ClientID: h.ClientID,
}
body, statusCode, err := requestToken(h.HTTPClient, url, tokenReqPayload)
if err != nil {
return TokenInfo{}, fmt.Errorf("wait for token: %v", err)
}
if statusCode > 499 {
return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
tokenResponse, err := h.requestToken(info)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
@@ -214,71 +237,6 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
}
}
// RotateAccessToken requests a new token using an existing refresh token
func (h *Hosted) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) {
url := "https://" + h.Domain + "/oauth/token"
tokenReqPayload := TokenRequestPayload{
GrantType: HostedRefreshGrant,
ClientID: h.ClientID,
RefreshToken: refreshToken,
}
body, statusCode, err := requestToken(h.HTTPClient, url, tokenReqPayload)
if err != nil {
return TokenInfo{}, fmt.Errorf("rotate access token: %v", err)
}
if statusCode != 200 {
return TokenInfo{}, fmt.Errorf("rotating token returned error: %s", string(body))
}
tokenResponse := TokenRequestResponse{}
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
}
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
if err != nil {
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
}
tokenInfo := TokenInfo{
AccessToken: tokenResponse.AccessToken,
TokenType: tokenResponse.TokenType,
RefreshToken: tokenResponse.RefreshToken,
IDToken: tokenResponse.IDToken,
ExpiresIn: tokenResponse.ExpiresIn,
}
return tokenInfo, err
}
func requestToken(client HTTPClient, url string, tokenReqPayload TokenRequestPayload) ([]byte, int, error) {
p, err := json.Marshal(tokenReqPayload)
if err != nil {
return nil, 0, fmt.Errorf("parsing token payload failed with error: %v", err)
}
payload := strings.NewReader(string(p))
req, err := http.NewRequest("POST", url, payload)
if err != nil {
return nil, 0, fmt.Errorf("creating token request failed with error: %v", err)
}
req.Header.Add("content-type", "application/json")
res, err := client.Do(req)
if err != nil {
return nil, 0, fmt.Errorf("doing token request failed with error: %v", err)
}
defer res.Body.Close()
body, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, 0, fmt.Errorf("reading token body failed with error: %v", err)
}
return body, res.StatusCode, nil
}
// isValidAccessToken is a simple validation of the access token
func isValidAccessToken(token string, audience string) error {
if token == "" {
@@ -297,9 +255,24 @@ func isValidAccessToken(token string, audience string) error {
return err
}
if claims.Audience != audience {
return fmt.Errorf("invalid audience")
if claims.Audience == nil {
return fmt.Errorf("required token field audience is absent")
}
return nil
// Audience claim of JWT can be a string or an array of strings
typ := reflect.TypeOf(claims.Audience)
switch typ.Kind() {
case reflect.String:
if claims.Audience == audience {
return nil
}
case reflect.Slice:
for _, aud := range claims.Audience.([]interface{}) {
if audience == aud {
return nil
}
}
}
return fmt.Errorf("invalid JWT token audience field")
}

View File

@@ -2,12 +2,12 @@ package internal
import (
"context"
"encoding/json"
"fmt"
"github.com/golang-jwt/jwt"
"github.com/stretchr/testify/require"
"io/ioutil"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
@@ -24,7 +24,7 @@ type mockHTTPClient struct {
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
body, err := ioutil.ReadAll(req.Body)
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
}
@@ -33,13 +33,13 @@ func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
c.count++
return &http.Response{
StatusCode: c.code,
Body: ioutil.NopCloser(strings.NewReader(c.countResBody)),
Body: io.NopCloser(strings.NewReader(c.countResBody)),
}, c.err
}
return &http.Response{
StatusCode: c.code,
Body: ioutil.NopCloser(strings.NewReader(c.resBody)),
Body: io.NopCloser(strings.NewReader(c.resBody)),
}, c.err
}
@@ -54,15 +54,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingFunc require.ComparisonAssertionFunc
expectedOut DeviceAuthInfo
expectedMSG string
expectPayload RequestDeviceCodePayload
expectPayload string
}
expectedAudience := "ok"
expectedClientID := "bla"
form := url.Values{}
form.Add("audience", expectedAudience)
form.Add("client_id", expectedClientID)
expectPayload := form.Encode()
testCase1 := test{
name: "Payload Is Valid",
expectPayload: RequestDeviceCodePayload{
Audience: "ok",
ClientID: "bla",
},
name: "Payload Is Valid",
expectPayload: expectPayload,
inputReqCode: 200,
testingErrFunc: require.Error,
testingFunc: require.EqualValues,
@@ -74,6 +78,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingErrFunc: require.Error,
expectedErrorMSG: "should return error",
testingFunc: require.EqualValues,
expectPayload: expectPayload,
}
testCase3 := test{
@@ -82,15 +87,13 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
testingErrFunc: require.Error,
expectedErrorMSG: "should return error",
testingFunc: require.EqualValues,
expectPayload: expectPayload,
}
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
testCase4 := test{
name: "Got Device Code",
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
expectPayload: RequestDeviceCodePayload{
Audience: "ok",
ClientID: "bla",
},
name: "Got Device Code",
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
expectPayload: expectPayload,
inputReqCode: 200,
testingErrFunc: require.NoError,
testingFunc: require.EqualValues,
@@ -108,18 +111,17 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
}
hosted := Hosted{
Audience: testCase.expectPayload.Audience,
ClientID: testCase.expectPayload.ClientID,
Domain: "test.hosted.com",
HTTPClient: &httpClient,
Audience: expectedAudience,
ClientID: expectedClientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
HTTPClient: &httpClient,
}
authInfo, err := hosted.RequestDeviceCode(context.TODO())
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
payload, _ := json.Marshal(testCase.expectPayload)
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG)
@@ -143,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
testingFunc require.ComparisonAssertionFunc
expectedOut TokenInfo
expectedMSG string
expectPayload TokenRequestPayload
expectPayload string
}
defaultInfo := DeviceAuthInfo{
@@ -152,11 +154,13 @@ func TestHosted_WaitToken(t *testing.T) {
Interval: 1,
}
tokenReqPayload := TokenRequestPayload{
GrantType: HostedGrantType,
DeviceCode: defaultInfo.DeviceCode,
ClientID: "test",
}
clientID := "test"
form := url.Values{}
form.Add("grant_type", HostedGrantType)
form.Add("device_code", defaultInfo.DeviceCode)
form.Add("client_id", clientID)
tokenReqPayload := form.Encode()
testCase1 := test{
name: "Payload Is Valid",
@@ -268,10 +272,11 @@ func TestHosted_WaitToken(t *testing.T) {
}
hosted := Hosted{
Audience: testCase.inputAudience,
ClientID: testCase.expectPayload.ClientID,
Domain: "test.hosted.com",
HTTPClient: &httpClient,
Audience: testCase.inputAudience,
ClientID: clientID,
TokenEndpoint: "test.hosted.com/token",
DeviceAuthEndpoint: "test.hosted.com/device/auth",
HTTPClient: &httpClient,
}
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
@@ -279,12 +284,7 @@ func TestHosted_WaitToken(t *testing.T) {
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
var payload []byte
var emptyPayload TokenRequestPayload
if testCase.expectPayload != emptyPayload {
payload, _ = json.Marshal(testCase.expectPayload)
}
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
@@ -293,123 +293,3 @@ func TestHosted_WaitToken(t *testing.T) {
})
}
}
func TestHosted_RotateAccessToken(t *testing.T) {
type test struct {
name string
inputResBody string
inputReqCode int
inputReqError error
inputMaxReqs int
inputInfo DeviceAuthInfo
inputAudience string
testingErrFunc require.ErrorAssertionFunc
expectedErrorMSG string
testingFunc require.ComparisonAssertionFunc
expectedOut TokenInfo
expectedMSG string
expectPayload TokenRequestPayload
}
defaultInfo := DeviceAuthInfo{
DeviceCode: "test",
ExpiresIn: 10,
Interval: 1,
}
tokenReqPayload := TokenRequestPayload{
GrantType: HostedRefreshGrant,
ClientID: "test",
RefreshToken: "refresh_test",
}
testCase1 := test{
name: "Payload Is Valid",
inputInfo: defaultInfo,
inputReqCode: 200,
testingErrFunc: require.Error,
testingFunc: require.EqualValues,
expectPayload: tokenReqPayload,
}
testCase2 := test{
name: "Exit On Network Error",
inputInfo: defaultInfo,
expectPayload: tokenReqPayload,
inputReqError: fmt.Errorf("error"),
testingErrFunc: require.Error,
expectedErrorMSG: "should return error",
testingFunc: require.EqualValues,
}
testCase3 := test{
name: "Exit On Non 200 Status Code",
inputInfo: defaultInfo,
inputReqCode: 401,
expectPayload: tokenReqPayload,
testingErrFunc: require.Error,
expectedErrorMSG: "should return error",
testingFunc: require.EqualValues,
}
audience := "test"
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
var hmacSampleSecret []byte
tokenString, _ := token.SignedString(hmacSampleSecret)
testCase4 := test{
name: "Exit On Invalid Audience",
inputInfo: defaultInfo,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
inputReqCode: 200,
inputAudience: "super test",
testingErrFunc: require.Error,
testingFunc: require.EqualValues,
expectPayload: tokenReqPayload,
}
testCase5 := test{
name: "Received Token Info",
inputInfo: defaultInfo,
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
inputReqCode: 200,
inputAudience: audience,
testingErrFunc: require.NoError,
testingFunc: require.EqualValues,
expectPayload: tokenReqPayload,
expectedOut: TokenInfo{AccessToken: tokenString},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
t.Run(testCase.name, func(t *testing.T) {
httpClient := mockHTTPClient{
resBody: testCase.inputResBody,
code: testCase.inputReqCode,
err: testCase.inputReqError,
MaxReqs: testCase.inputMaxReqs,
}
hosted := Hosted{
Audience: testCase.inputAudience,
ClientID: testCase.expectPayload.ClientID,
Domain: "test.hosted.com",
HTTPClient: &httpClient,
}
tokenInfo, err := hosted.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken)
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
var payload []byte
var emptyPayload TokenRequestPayload
if testCase.expectPayload != emptyPayload {
payload, _ = json.Marshal(testCase.expectPayload)
}
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
})
}
}

View File

@@ -1,9 +1,8 @@
package main
import (
"os"
"github.com/netbirdio/netbird/client/cmd"
"os"
)
func main() {

View File

@@ -208,7 +208,8 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
hostedClient := internal.NewHostedDeviceFlow(
providerConfig.ProviderConfig.Audience,
providerConfig.ProviderConfig.ClientID,
providerConfig.ProviderConfig.Domain,
providerConfig.ProviderConfig.TokenEndpoint,
providerConfig.ProviderConfig.DeviceAuthEndpoint,
)
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {