diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 39da2a79f..c03376f5b 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" + "github.com/netbirdio/netbird/shared/management/client/common" ) var _ OAuthFlow = &PKCEAuthorizationFlow{} @@ -104,11 +105,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), } if !p.providerConfig.DisablePromptLogin { - if p.providerConfig.LoginFlag.IsPromptLogin() { - params = append(params, oauth2.SetAuthURLParam("prompt", "login")) - } - if p.providerConfig.LoginFlag.IsMaxAge0Login() { + switch p.providerConfig.LoginFlag { + case common.LoginFlagPromptLogin: + params = append(params, oauth2.SetAuthURLParam("prompt", "login select_account")) + case common.LoginFlagMaxAge0: params = append(params, oauth2.SetAuthURLParam("max_age", "0")) + params = append(params, oauth2.SetAuthURLParam("prompt", "select_account")) } } if p.providerConfig.LoginHint != "" { diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index 380a360e5..b5843f104 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -15,30 +15,37 @@ import ( func TestPromptLogin(t *testing.T) { const ( - promptLogin = "prompt=login" - maxAge0 = "max_age=0" + promptSelectAccountLogin = "prompt=login+select_account" + promptSelectAccount = "prompt=select_account" + maxAge0 = "max_age=0" ) tt := []struct { name string loginFlag mgm.LoginFlag disablePromptLogin bool - expect string + expectContains []string }{ { - name: "Prompt login", - loginFlag: mgm.LoginFlagPrompt, - expect: promptLogin, + name: "Prompt login with select account", + loginFlag: mgm.LoginFlagPromptLogin, + expectContains: []string{promptSelectAccountLogin}, }, { - name: "Max age 0 login", - loginFlag: mgm.LoginFlagMaxAge0, - expect: maxAge0, + name: "Max age 0 with select account", + loginFlag: mgm.LoginFlagMaxAge0, + expectContains: []string{maxAge0, promptSelectAccount}, }, { name: "Disable prompt login", - loginFlag: mgm.LoginFlagPrompt, + loginFlag: mgm.LoginFlagPromptLogin, disablePromptLogin: true, + expectContains: []string{}, + }, + { + name: "None flag should not add parameters", + loginFlag: mgm.LoginFlagNone, + expectContains: []string{}, }, } @@ -53,6 +60,7 @@ func TestPromptLogin(t *testing.T) { RedirectURLs: []string{"http://127.0.0.1:33992/"}, UseIDToken: true, LoginFlag: tc.loginFlag, + DisablePromptLogin: tc.disablePromptLogin, } pkce, err := NewPKCEAuthorizationFlow(config) if err != nil { @@ -63,11 +71,8 @@ func TestPromptLogin(t *testing.T) { t.Fatalf("Failed to request auth info: %v", err) } - if !tc.disablePromptLogin { - require.Contains(t, authInfo.VerificationURIComplete, tc.expect) - } else { - require.Contains(t, authInfo.VerificationURIComplete, promptLogin) - require.NotContains(t, authInfo.VerificationURIComplete, maxAge0) + for _, expected := range tc.expectContains { + require.Contains(t, authInfo.VerificationURIComplete, expected) } }) } diff --git a/shared/management/client/common/types.go b/shared/management/client/common/types.go index 699617574..550bcde30 100644 --- a/shared/management/client/common/types.go +++ b/shared/management/client/common/types.go @@ -1,19 +1,20 @@ package common -// LoginFlag introduces additional login flags to the PKCE authorization request +// LoginFlag introduces additional login flags to the PKCE authorization request. +// +// # Config Values +// +// | Value | Flag | OAuth Parameters | +// |-------|----------------------|-----------------------------------------| +// | 0 | LoginFlagPromptLogin | prompt=select_account login | +// | 1 | LoginFlagMaxAge0 | max_age=0 & prompt=select_account | type LoginFlag uint8 const ( - // LoginFlagPrompt adds prompt=login to the authorization request - LoginFlagPrompt LoginFlag = iota - // LoginFlagMaxAge0 adds max_age=0 to the authorization request + // LoginFlagPromptLogin adds prompt=select_account login to the authorization request + LoginFlagPromptLogin LoginFlag = iota + // LoginFlagMaxAge0 adds max_age=0 and prompt=select_account to the authorization request LoginFlagMaxAge0 + // LoginFlagNone disables all login flags + LoginFlagNone ) - -func (l LoginFlag) IsPromptLogin() bool { - return l == LoginFlagPrompt -} - -func (l LoginFlag) IsMaxAge0Login() bool { - return l == LoginFlagMaxAge0 -}