diff --git a/client/cmd/login.go b/client/cmd/login.go index a5cc3215c..5433db522 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -3,8 +3,6 @@ package cmd import ( "context" "fmt" - "os" - "runtime" "strings" "time" @@ -195,51 +193,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) } - browserAuthMsg := "Please do the SSO login in your browser. \n" + + cmd.Println("Please do the SSO login in your browser. \n" + "If your browser didn't open automatically, use this URL to log in:\n\n" + - verificationURIComplete + " " + codeMsg - - setupKeyAuthMsg := "\nAlternatively, you may want to use a setup key, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys" - - authenticateUsingBrowser := func() { - cmd.Println(browserAuthMsg) - cmd.Println("") - if err := open.Run(verificationURIComplete); err != nil { - cmd.Println(setupKeyAuthMsg) - } - } - - switch runtime.GOOS { - case "windows", "darwin": - authenticateUsingBrowser() - case "linux": - if isLinuxRunningDesktop() { - authenticateUsingBrowser() - } else { - // If current flow is PKCE, it implies the server is anticipating the redirect to localhost. - // Devices lacking browser support are incompatible with this flow.Therefore, - // these devices will need to resort to setup keys instead. - if isPKCEFlow(verificationURIComplete) { - cmd.Println("Please proceed with setting up this device using setup keys, see:\n\n" + - "https://docs.netbird.io/how-to/register-machines-using-setup-keys") - } else { - cmd.Println(browserAuthMsg) - } - } + verificationURIComplete + " " + codeMsg) + cmd.Println("") + if err := open.Run(verificationURIComplete); err != nil { + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } - -// isLinuxRunningDesktop checks if a Linux OS is running desktop environment. -func isLinuxRunningDesktop() bool { - return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" -} - -// isPKCEFlow determines if the PKCE flow is active or not, -// by checking the existence of redirect_uri inside the verification URL. -func isPKCEFlow(verificationURL string) bool { - if verificationURL == "" { - return false - } - return strings.Contains(verificationURL, "redirect_uri") -} diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 794fe0958..8731e4f0b 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -4,8 +4,8 @@ import ( "context" "fmt" "net/http" + "runtime" - log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -57,25 +57,43 @@ func (t TokenInfo) GetTokenToUse() string { return t.AccessToken } -// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration. +// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration +// +// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow, +// and if that also fails, the authentication process is deemed unsuccessful +// +// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { - log.Debug("loading pkce authorization flow info") - - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) - if err == nil { - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + if runtime.GOOS == "linux" && !isLinuxRunningDesktop() { + return authenticateWithDeviceCodeFlow(ctx, config) } - log.Debugf("loading pkce authorization flow info failed with error: %v", err) - log.Debugf("falling back to device authorization flow info") + pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + if err != nil { + // fallback to device code flow + return authenticateWithDeviceCodeFlow(ctx, config) + } + return pkceFlow, nil +} +// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow +func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { + pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + if err != nil { + return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) + } + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) +} + +// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow +func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { s, ok := gstatus.FromError(err) if ok && s.Code() == codes.NotFound { return nil, fmt.Errorf("no SSO provider returned from management. " + - "If you are using hosting Netbird see documentation at " + - "https://github.com/netbirdio/netbird/tree/main/management for details") + "Please proceed with setting up this device using setup keys " + + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } else if ok && s.Code() == codes.Unimplemented { return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+ "please update your server or use Setup Keys to login", config.ManagementURL) diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index d15d49373..a3d0c1309 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -12,7 +12,6 @@ import ( "net/http" "net/url" "strings" - "sync" "time" log "github.com/sirupsen/logrus" @@ -80,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string { } // RequestAuthInfo requests a authorization code login flow information. -func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) { +func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { state, err := randomBytesInHex(24) if err != nil { return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err) @@ -114,64 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) - go p.startServer(tokenChan, errChan) + parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) + if err != nil { + return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err) + } + + server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} + defer func() { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + log.Errorf("failed to close the server: %v", err) + } + }() + + go p.startServer(server, tokenChan, errChan) select { case <-ctx.Done(): return TokenInfo{}, ctx.Err() case token := <-tokenChan: - return p.handleOAuthToken(token) + return p.parseOAuthToken(token) case err := <-errChan: return TokenInfo{}, err } } -func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) { - var wg sync.WaitGroup - - parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL) - if err != nil { - errChan <- fmt.Errorf("failed to parse redirect URL: %v", err) - return - } - - server := http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} - go func() { - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- err - } - }() - - wg.Add(1) - http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - defer wg.Done() - - tokenValidatorFunc := func() (*oauth2.Token, error) { - query := req.URL.Query() - - if authError := query.Get(queryError); authError != "" { - authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) - } - - // Prevent timing attacks on state - if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") - } - - code := query.Get(queryCode) - if code == "" { - return nil, fmt.Errorf("missing code") - } - - return p.oAuthConfig.Exchange( - req.Context(), - code, - oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), - ) - } - - token, err := tokenValidatorFunc() +func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + token, err := p.handleRequest(req) if err != nil { renderPKCEFlowTmpl(w, err) errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) @@ -182,13 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC tokenChan <- token }) - wg.Wait() - if err := server.Shutdown(context.Background()); err != nil { - log.Errorf("error while shutting down pkce flow server: %v", err) + server.Handler = mux + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err } } -func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) { +func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) { + query := req.URL.Query() + + if authError := query.Get(queryError); authError != "" { + authErrorDesc := query.Get(queryErrorDesc) + return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + } + + // Prevent timing attacks on the state + if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { + return nil, fmt.Errorf("invalid state") + } + + code := query.Get(queryCode) + if code == "" { + return nil, fmt.Errorf("missing code") + } + + return p.oAuthConfig.Exchange( + req.Context(), + code, + oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), + ) +} + +func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) { tokenInfo := TokenInfo{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, diff --git a/client/internal/auth/util.go b/client/internal/auth/util.go index 33a0e6e35..e61e0f175 100644 --- a/client/internal/auth/util.go +++ b/client/internal/auth/util.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "os" "reflect" "strings" ) @@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error { return fmt.Errorf("invalid JWT token audience field") } + +// isLinuxRunningDesktop checks if a Linux OS is running desktop environment +func isLinuxRunningDesktop() bool { + return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" +}