mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Support Generic OAuth 2.0 Device Authorization Grant (#433)
Support Generic OAuth 2.0 Device Authorization Grant as per RFC specification https://www.rfc-editor.org/rfc/rfc8628. The previous version supported only Auth0 as an IDP backend. This implementation enables the Interactive SSO Login feature for any IDP compatible with the specification, e.g., Keycloak.
This commit is contained in:
@@ -169,7 +169,8 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
||||
hostedClient := internal.NewHostedDeviceFlow(
|
||||
providerConfig.ProviderConfig.Audience,
|
||||
providerConfig.ProviderConfig.ClientID,
|
||||
providerConfig.ProviderConfig.Domain,
|
||||
providerConfig.ProviderConfig.TokenEndpoint,
|
||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||
)
|
||||
|
||||
flowInfo, err := hostedClient.RequestDeviceCode(context.TODO())
|
||||
|
||||
@@ -197,9 +197,14 @@ type ProviderConfig struct {
|
||||
// ClientSecret An IDP application client secret
|
||||
ClientSecret string
|
||||
// Domain An IDP API domain
|
||||
// Deprecated. Use OIDCConfigEndpoint instead
|
||||
Domain string
|
||||
// Audience An Audience for to authorization validation
|
||||
Audience string
|
||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
}
|
||||
|
||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (DeviceAuthorizationFlow, error) {
|
||||
@@ -250,10 +255,12 @@ func GetDeviceAuthorizationFlowInfo(ctx context.Context, config *Config) (Device
|
||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
||||
|
||||
ProviderConfig: ProviderConfig{
|
||||
Audience: protoDeviceAuthorizationFlow.ProviderConfig.Audience,
|
||||
ClientID: protoDeviceAuthorizationFlow.ProviderConfig.ClientID,
|
||||
ClientSecret: protoDeviceAuthorizationFlow.ProviderConfig.ClientSecret,
|
||||
Domain: protoDeviceAuthorizationFlow.ProviderConfig.Domain,
|
||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -14,7 +16,6 @@ import (
|
||||
// OAuthClient is a OAuth client interface for various idp providers
|
||||
type OAuthClient interface {
|
||||
RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error)
|
||||
WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error)
|
||||
GetClientID(ctx context.Context) string
|
||||
}
|
||||
@@ -55,8 +56,10 @@ type Hosted struct {
|
||||
Audience string
|
||||
// Hosted Native application client id
|
||||
ClientID string
|
||||
// Hosted domain
|
||||
Domain string
|
||||
// TokenEndpoint to request access token
|
||||
TokenEndpoint string
|
||||
// DeviceAuthEndpoint to request device authorization code
|
||||
DeviceAuthEndpoint string
|
||||
|
||||
HTTPClient HTTPClient
|
||||
}
|
||||
@@ -84,11 +87,11 @@ type TokenRequestResponse struct {
|
||||
|
||||
// Claims used when validating the access token
|
||||
type Claims struct {
|
||||
Audience string `json:"aud"`
|
||||
Audience interface{} `json:"aud"`
|
||||
}
|
||||
|
||||
// NewHostedDeviceFlow returns an Hosted OAuth client
|
||||
func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hosted {
|
||||
func NewHostedDeviceFlow(audience string, clientID string, tokenEndpoint string, deviceAuthEndpoint string) *Hosted {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
@@ -98,10 +101,11 @@ func NewHostedDeviceFlow(audience string, clientID string, domain string) *Hoste
|
||||
}
|
||||
|
||||
return &Hosted{
|
||||
Audience: audience,
|
||||
ClientID: clientID,
|
||||
Domain: domain,
|
||||
HTTPClient: httpClient,
|
||||
Audience: audience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: tokenEndpoint,
|
||||
HTTPClient: httpClient,
|
||||
DeviceAuthEndpoint: deviceAuthEndpoint,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,22 +116,15 @@ func (h *Hosted) GetClientID(ctx context.Context) string {
|
||||
|
||||
// RequestDeviceCode requests a device code login flow information from Hosted
|
||||
func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error) {
|
||||
url := "https://" + h.Domain + "/oauth/device/code"
|
||||
codePayload := RequestDeviceCodePayload{
|
||||
Audience: h.Audience,
|
||||
ClientID: h.ClientID,
|
||||
}
|
||||
p, err := json.Marshal(codePayload)
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("parsing payload failed with error: %v", err)
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
req, err := http.NewRequest("POST", url, payload)
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.ClientID)
|
||||
form.Add("audience", h.Audience)
|
||||
req, err := http.NewRequest("POST", h.DeviceAuthEndpoint,
|
||||
strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("creating request failed with error: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := h.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -135,7 +132,7 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return DeviceAuthInfo{}, fmt.Errorf("reading body failed with error: %v", err)
|
||||
}
|
||||
@@ -153,6 +150,48 @@ func (h *Hosted) RequestDeviceCode(ctx context.Context) (DeviceAuthInfo, error)
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func (h *Hosted) requestToken(info DeviceAuthInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", h.ClientID)
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", info.DeviceCode)
|
||||
req, err := http.NewRequest("POST", h.TokenEndpoint, strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to create request access token: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := h.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed to request access token with error: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := res.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("failed reading access token response body with error: %v", err)
|
||||
}
|
||||
|
||||
if res.StatusCode > 499 {
|
||||
return TokenRequestResponse{}, fmt.Errorf("access token response returned code: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenRequestResponse{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
||||
func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo, error) {
|
||||
@@ -163,24 +202,8 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case <-ticker.C:
|
||||
url := "https://" + h.Domain + "/oauth/token"
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: HostedGrantType,
|
||||
DeviceCode: info.DeviceCode,
|
||||
ClientID: h.ClientID,
|
||||
}
|
||||
|
||||
body, statusCode, err := requestToken(h.HTTPClient, url, tokenReqPayload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("wait for token: %v", err)
|
||||
}
|
||||
|
||||
if statusCode > 499 {
|
||||
return TokenInfo{}, fmt.Errorf("wait token code returned error: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
tokenResponse, err := h.requestToken(info)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
@@ -214,71 +237,6 @@ func (h *Hosted) WaitToken(ctx context.Context, info DeviceAuthInfo) (TokenInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// RotateAccessToken requests a new token using an existing refresh token
|
||||
func (h *Hosted) RotateAccessToken(ctx context.Context, refreshToken string) (TokenInfo, error) {
|
||||
url := "https://" + h.Domain + "/oauth/token"
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: HostedRefreshGrant,
|
||||
ClientID: h.ClientID,
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
body, statusCode, err := requestToken(h.HTTPClient, url, tokenReqPayload)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("rotate access token: %v", err)
|
||||
}
|
||||
|
||||
if statusCode != 200 {
|
||||
return TokenInfo{}, fmt.Errorf("rotating token returned error: %s", string(body))
|
||||
}
|
||||
|
||||
tokenResponse := TokenRequestResponse{}
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("parsing token response failed with error: %v", err)
|
||||
}
|
||||
|
||||
err = isValidAccessToken(tokenResponse.AccessToken, h.Audience)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err)
|
||||
}
|
||||
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: tokenResponse.AccessToken,
|
||||
TokenType: tokenResponse.TokenType,
|
||||
RefreshToken: tokenResponse.RefreshToken,
|
||||
IDToken: tokenResponse.IDToken,
|
||||
ExpiresIn: tokenResponse.ExpiresIn,
|
||||
}
|
||||
return tokenInfo, err
|
||||
}
|
||||
|
||||
func requestToken(client HTTPClient, url string, tokenReqPayload TokenRequestPayload) ([]byte, int, error) {
|
||||
p, err := json.Marshal(tokenReqPayload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("parsing token payload failed with error: %v", err)
|
||||
}
|
||||
payload := strings.NewReader(string(p))
|
||||
req, err := http.NewRequest("POST", url, payload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("creating token request failed with error: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("doing token request failed with error: %v", err)
|
||||
}
|
||||
|
||||
defer res.Body.Close()
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("reading token body failed with error: %v", err)
|
||||
}
|
||||
return body, res.StatusCode, nil
|
||||
}
|
||||
|
||||
// isValidAccessToken is a simple validation of the access token
|
||||
func isValidAccessToken(token string, audience string) error {
|
||||
if token == "" {
|
||||
@@ -297,9 +255,24 @@ func isValidAccessToken(token string, audience string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if claims.Audience != audience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
if claims.Audience == nil {
|
||||
return fmt.Errorf("required token field audience is absent")
|
||||
}
|
||||
|
||||
return nil
|
||||
// Audience claim of JWT can be a string or an array of strings
|
||||
typ := reflect.TypeOf(claims.Audience)
|
||||
switch typ.Kind() {
|
||||
case reflect.String:
|
||||
if claims.Audience == audience {
|
||||
return nil
|
||||
}
|
||||
case reflect.Slice:
|
||||
for _, aud := range claims.Audience.([]interface{}) {
|
||||
if audience == aud {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid JWT token audience field")
|
||||
}
|
||||
|
||||
@@ -2,12 +2,12 @@ package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -24,7 +24,7 @@ type mockHTTPClient struct {
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
@@ -33,13 +33,13 @@ func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
c.count++
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: ioutil.NopCloser(strings.NewReader(c.countResBody)),
|
||||
Body: io.NopCloser(strings.NewReader(c.countResBody)),
|
||||
}, c.err
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
Body: ioutil.NopCloser(strings.NewReader(c.resBody)),
|
||||
Body: io.NopCloser(strings.NewReader(c.resBody)),
|
||||
}, c.err
|
||||
}
|
||||
|
||||
@@ -54,15 +54,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedOut DeviceAuthInfo
|
||||
expectedMSG string
|
||||
expectPayload RequestDeviceCodePayload
|
||||
expectPayload string
|
||||
}
|
||||
|
||||
expectedAudience := "ok"
|
||||
expectedClientID := "bla"
|
||||
form := url.Values{}
|
||||
form.Add("audience", expectedAudience)
|
||||
form.Add("client_id", expectedClientID)
|
||||
expectPayload := form.Encode()
|
||||
|
||||
testCase1 := test{
|
||||
name: "Payload Is Valid",
|
||||
expectPayload: RequestDeviceCodePayload{
|
||||
Audience: "ok",
|
||||
ClientID: "bla",
|
||||
},
|
||||
name: "Payload Is Valid",
|
||||
expectPayload: expectPayload,
|
||||
inputReqCode: 200,
|
||||
testingErrFunc: require.Error,
|
||||
testingFunc: require.EqualValues,
|
||||
@@ -74,6 +78,7 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingErrFunc: require.Error,
|
||||
expectedErrorMSG: "should return error",
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: expectPayload,
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
@@ -82,15 +87,13 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
testingErrFunc: require.Error,
|
||||
expectedErrorMSG: "should return error",
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: expectPayload,
|
||||
}
|
||||
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
||||
testCase4 := test{
|
||||
name: "Got Device Code",
|
||||
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
||||
expectPayload: RequestDeviceCodePayload{
|
||||
Audience: "ok",
|
||||
ClientID: "bla",
|
||||
},
|
||||
name: "Got Device Code",
|
||||
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
||||
expectPayload: expectPayload,
|
||||
inputReqCode: 200,
|
||||
testingErrFunc: require.NoError,
|
||||
testingFunc: require.EqualValues,
|
||||
@@ -108,18 +111,17 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: testCase.expectPayload.Audience,
|
||||
ClientID: testCase.expectPayload.ClientID,
|
||||
Domain: "test.hosted.com",
|
||||
HTTPClient: &httpClient,
|
||||
Audience: expectedAudience,
|
||||
ClientID: expectedClientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
authInfo, err := hosted.RequestDeviceCode(context.TODO())
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
payload, _ := json.Marshal(testCase.expectPayload)
|
||||
|
||||
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
||||
require.EqualValues(t, expectPayload, httpClient.reqBody, "payload should match")
|
||||
|
||||
testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG)
|
||||
|
||||
@@ -143,7 +145,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedOut TokenInfo
|
||||
expectedMSG string
|
||||
expectPayload TokenRequestPayload
|
||||
expectPayload string
|
||||
}
|
||||
|
||||
defaultInfo := DeviceAuthInfo{
|
||||
@@ -152,11 +154,13 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
Interval: 1,
|
||||
}
|
||||
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: HostedGrantType,
|
||||
DeviceCode: defaultInfo.DeviceCode,
|
||||
ClientID: "test",
|
||||
}
|
||||
clientID := "test"
|
||||
|
||||
form := url.Values{}
|
||||
form.Add("grant_type", HostedGrantType)
|
||||
form.Add("device_code", defaultInfo.DeviceCode)
|
||||
form.Add("client_id", clientID)
|
||||
tokenReqPayload := form.Encode()
|
||||
|
||||
testCase1 := test{
|
||||
name: "Payload Is Valid",
|
||||
@@ -268,10 +272,11 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: testCase.expectPayload.ClientID,
|
||||
Domain: "test.hosted.com",
|
||||
HTTPClient: &httpClient,
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: clientID,
|
||||
TokenEndpoint: "test.hosted.com/token",
|
||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||
@@ -279,12 +284,7 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
tokenInfo, err := hosted.WaitToken(ctx, testCase.inputInfo)
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
var payload []byte
|
||||
var emptyPayload TokenRequestPayload
|
||||
if testCase.expectPayload != emptyPayload {
|
||||
payload, _ = json.Marshal(testCase.expectPayload)
|
||||
}
|
||||
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
||||
require.EqualValues(t, testCase.expectPayload, httpClient.reqBody, "payload should match")
|
||||
|
||||
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
||||
|
||||
@@ -293,123 +293,3 @@ func TestHosted_WaitToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHosted_RotateAccessToken(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputResBody string
|
||||
inputReqCode int
|
||||
inputReqError error
|
||||
inputMaxReqs int
|
||||
inputInfo DeviceAuthInfo
|
||||
inputAudience string
|
||||
testingErrFunc require.ErrorAssertionFunc
|
||||
expectedErrorMSG string
|
||||
testingFunc require.ComparisonAssertionFunc
|
||||
expectedOut TokenInfo
|
||||
expectedMSG string
|
||||
expectPayload TokenRequestPayload
|
||||
}
|
||||
|
||||
defaultInfo := DeviceAuthInfo{
|
||||
DeviceCode: "test",
|
||||
ExpiresIn: 10,
|
||||
Interval: 1,
|
||||
}
|
||||
|
||||
tokenReqPayload := TokenRequestPayload{
|
||||
GrantType: HostedRefreshGrant,
|
||||
ClientID: "test",
|
||||
RefreshToken: "refresh_test",
|
||||
}
|
||||
|
||||
testCase1 := test{
|
||||
name: "Payload Is Valid",
|
||||
inputInfo: defaultInfo,
|
||||
inputReqCode: 200,
|
||||
testingErrFunc: require.Error,
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: tokenReqPayload,
|
||||
}
|
||||
|
||||
testCase2 := test{
|
||||
name: "Exit On Network Error",
|
||||
inputInfo: defaultInfo,
|
||||
expectPayload: tokenReqPayload,
|
||||
inputReqError: fmt.Errorf("error"),
|
||||
testingErrFunc: require.Error,
|
||||
expectedErrorMSG: "should return error",
|
||||
testingFunc: require.EqualValues,
|
||||
}
|
||||
|
||||
testCase3 := test{
|
||||
name: "Exit On Non 200 Status Code",
|
||||
inputInfo: defaultInfo,
|
||||
inputReqCode: 401,
|
||||
expectPayload: tokenReqPayload,
|
||||
testingErrFunc: require.Error,
|
||||
expectedErrorMSG: "should return error",
|
||||
testingFunc: require.EqualValues,
|
||||
}
|
||||
|
||||
audience := "test"
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
|
||||
var hmacSampleSecret []byte
|
||||
tokenString, _ := token.SignedString(hmacSampleSecret)
|
||||
|
||||
testCase4 := test{
|
||||
name: "Exit On Invalid Audience",
|
||||
inputInfo: defaultInfo,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
||||
inputReqCode: 200,
|
||||
inputAudience: "super test",
|
||||
testingErrFunc: require.Error,
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: tokenReqPayload,
|
||||
}
|
||||
|
||||
testCase5 := test{
|
||||
name: "Received Token Info",
|
||||
inputInfo: defaultInfo,
|
||||
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
||||
inputReqCode: 200,
|
||||
inputAudience: audience,
|
||||
testingErrFunc: require.NoError,
|
||||
testingFunc: require.EqualValues,
|
||||
expectPayload: tokenReqPayload,
|
||||
expectedOut: TokenInfo{AccessToken: tokenString},
|
||||
}
|
||||
|
||||
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
|
||||
httpClient := mockHTTPClient{
|
||||
resBody: testCase.inputResBody,
|
||||
code: testCase.inputReqCode,
|
||||
err: testCase.inputReqError,
|
||||
MaxReqs: testCase.inputMaxReqs,
|
||||
}
|
||||
|
||||
hosted := Hosted{
|
||||
Audience: testCase.inputAudience,
|
||||
ClientID: testCase.expectPayload.ClientID,
|
||||
Domain: "test.hosted.com",
|
||||
HTTPClient: &httpClient,
|
||||
}
|
||||
|
||||
tokenInfo, err := hosted.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken)
|
||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||
|
||||
var payload []byte
|
||||
var emptyPayload TokenRequestPayload
|
||||
if testCase.expectPayload != emptyPayload {
|
||||
payload, _ = json.Marshal(testCase.expectPayload)
|
||||
}
|
||||
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
||||
|
||||
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/client/cmd"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -208,7 +208,8 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
hostedClient := internal.NewHostedDeviceFlow(
|
||||
providerConfig.ProviderConfig.Audience,
|
||||
providerConfig.ProviderConfig.ClientID,
|
||||
providerConfig.ProviderConfig.Domain,
|
||||
providerConfig.ProviderConfig.TokenEndpoint,
|
||||
providerConfig.ProviderConfig.DeviceAuthEndpoint,
|
||||
)
|
||||
|
||||
if s.oauthAuthFlow.client != nil && s.oauthAuthFlow.client.GetClientID(ctx) == hostedClient.GetClientID(context.TODO()) {
|
||||
|
||||
Reference in New Issue
Block a user