diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 9451ce055..482b137b8 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -6,6 +6,7 @@ import ( "crypto/subtle" "encoding/base64" "fmt" + "html/template" "net" "net/http" "net/url" @@ -16,6 +17,7 @@ import ( "golang.org/x/oauth2" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/templates" ) var _ OAuthFlow = &PKCEAuthorizationFlow{} @@ -136,33 +138,35 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC }() http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { - query := req.URL.Query() + tokenValidatorFunc := func() (*oauth2.Token, error) { + query := req.URL.Query() - state := query.Get(queryState) - // Prevent timing attacks on state - if subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - errChan <- fmt.Errorf("invalid state") - return + state := query.Get(queryState) + // Prevent timing attacks on state + if subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { + return nil, fmt.Errorf("invalid state") + } + + code := query.Get(queryCode) + if code == "" { + return nil, fmt.Errorf("missing code") + } + + return p.oAuthConfig.Exchange( + req.Context(), + code, + oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), + ) } - code := query.Get(queryCode) - if code == "" { - errChan <- fmt.Errorf("missing code") - return - } - - // Exchange the authorization code for the OAuth token - token, err := p.oAuthConfig.Exchange( - req.Context(), - code, - oauth2.SetAuthURLParam("code_verifier", p.codeVerifier), - ) + token, err := tokenValidatorFunc() if err != nil { - errChan <- fmt.Errorf("OAuth token exchange failed: %v", err) - return + errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err) + renderPKCEFlowTmpl(w, err) } tokenChan <- token + renderPKCEFlowTmpl(w, nil) }) if err := server.ListenAndServe(); err != nil { @@ -215,3 +219,20 @@ func isRedirectURLPortUsed(redirectURL string) bool { return true } + +func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) { + tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + data := make(map[string]string) + if authError != nil { + data["Error"] = authError.Error() + } + + if err := tmpl.Execute(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} diff --git a/client/internal/templates/embed.go b/client/internal/templates/embed.go new file mode 100644 index 000000000..5c0ead176 --- /dev/null +++ b/client/internal/templates/embed.go @@ -0,0 +1,8 @@ +package templates + +import ( + _ "embed" +) + +//go:embed pkce-auth-msg.html +var PKCEAuthMsgTmpl string diff --git a/client/internal/templates/pkce-auth-msg.html b/client/internal/templates/pkce-auth-msg.html new file mode 100644 index 000000000..efd1e06a3 --- /dev/null +++ b/client/internal/templates/pkce-auth-msg.html @@ -0,0 +1,87 @@ + + +
+ + + + +
+