mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Remove client secret from gRPC auth flow. The secret was originally included to support providers like Google Workspace that don't offer a proper PKCE flow, but this is no longer necessary with the embedded IdP. Deployments using such providers should migrate to the embedded IdP instead.
417 lines
12 KiB
Go
417 lines
12 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"html/template"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/templates"
|
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
)
|
|
|
|
// Compile-time assertion that *PKCEAuthorizationFlow satisfies OAuthFlow.
var _ OAuthFlow = (*PKCEAuthorizationFlow)(nil)
|
|
|
|
// Names of the query parameters read from the OAuth redirect callback.
const (
	// queryState carries the state value that is compared with the one generated for the flow.
	queryState = "state"
	// queryCode carries the authorization code that is exchanged for a token.
	queryCode = "code"
	// queryError carries the error code reported by the provider.
	queryError = "error"
	// queryErrorDesc carries the human-readable description of that error.
	queryErrorDesc = "error_description"
)

// defaultPKCETimeoutSeconds is the ExpiresIn value, in seconds, reported by RequestAuthInfo.
const defaultPKCETimeoutSeconds = 300
|
|
|
|
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
|
type PKCEAuthProviderConfig struct {
|
|
// ClientID An IDP application client id
|
|
ClientID string
|
|
// Audience An Audience for to authorization validation
|
|
Audience string
|
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
TokenEndpoint string
|
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
|
AuthorizationEndpoint string
|
|
// Scopes provides the scopes to be included in the token request
|
|
Scope string
|
|
// RedirectURL handles authorization code from IDP manager
|
|
RedirectURLs []string
|
|
// UseIDToken indicates if the id token should be used for authentication
|
|
UseIDToken bool
|
|
// ClientCertPair is used for mTLS authentication to the IDP
|
|
ClientCertPair *tls.Certificate
|
|
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
|
DisablePromptLogin bool
|
|
// LoginFlag is used to configure the PKCE flow login behavior
|
|
LoginFlag common.LoginFlag
|
|
// LoginHint is used to pre-fill the email/username field during authentication
|
|
LoginHint string
|
|
}
|
|
|
|
// validatePKCEConfig validates PKCE provider configuration
|
|
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
|
|
if config.ClientID == "" {
|
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
|
}
|
|
if config.TokenEndpoint == "" {
|
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
|
}
|
|
if config.AuthorizationEndpoint == "" {
|
|
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
|
}
|
|
if config.Scope == "" {
|
|
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
|
}
|
|
if config.RedirectURLs == nil {
|
|
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
|
// the Authorization Code Flow with PKCE.
|
|
type PKCEAuthorizationFlow struct {
|
|
providerConfig PKCEAuthProviderConfig
|
|
state string
|
|
codeVerifier string
|
|
oAuthConfig *oauth2.Config
|
|
}
|
|
|
|
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
|
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
|
var availableRedirectURL string
|
|
|
|
excludedRanges := getSystemExcludedPortRanges()
|
|
|
|
for _, redirectURL := range config.RedirectURLs {
|
|
if !isRedirectURLPortUsed(redirectURL, excludedRanges) {
|
|
availableRedirectURL = redirectURL
|
|
break
|
|
}
|
|
}
|
|
|
|
if availableRedirectURL == "" {
|
|
return nil, fmt.Errorf("no available port found from configured redirect URLs: %q", config.RedirectURLs)
|
|
}
|
|
|
|
cfg := &oauth2.Config{
|
|
ClientID: config.ClientID,
|
|
Endpoint: oauth2.Endpoint{
|
|
AuthURL: config.AuthorizationEndpoint,
|
|
TokenURL: config.TokenEndpoint,
|
|
},
|
|
RedirectURL: availableRedirectURL,
|
|
Scopes: strings.Split(config.Scope, " "),
|
|
}
|
|
|
|
return &PKCEAuthorizationFlow{
|
|
providerConfig: config,
|
|
oAuthConfig: cfg,
|
|
}, nil
|
|
}
|
|
|
|
// GetClientID returns the provider client id
|
|
func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
|
return p.providerConfig.ClientID
|
|
}
|
|
|
|
// RequestAuthInfo requests a authorization code login flow information.
|
|
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
|
state, err := randomBytesInHex(24)
|
|
if err != nil {
|
|
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
|
}
|
|
p.state = state
|
|
|
|
codeVerifier, err := randomBytesInHex(64)
|
|
if err != nil {
|
|
return AuthFlowInfo{}, fmt.Errorf("could not create a code verifier: %v", err)
|
|
}
|
|
p.codeVerifier = codeVerifier
|
|
|
|
codeChallenge := createCodeChallenge(codeVerifier)
|
|
|
|
params := []oauth2.AuthCodeOption{
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
|
}
|
|
if !p.providerConfig.DisablePromptLogin {
|
|
switch p.providerConfig.LoginFlag {
|
|
case common.LoginFlagPromptLogin:
|
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
|
case common.LoginFlagMaxAge0:
|
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
|
}
|
|
}
|
|
if p.providerConfig.LoginHint != "" {
|
|
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
|
}
|
|
|
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
|
|
|
return AuthFlowInfo{
|
|
VerificationURIComplete: authURL,
|
|
ExpiresIn: defaultPKCETimeoutSeconds,
|
|
}, nil
|
|
}
|
|
|
|
// SetLoginHint sets the login hint for the PKCE authorization flow
|
|
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
|
p.providerConfig.LoginHint = hint
|
|
}
|
|
|
|
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
|
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
|
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
|
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
|
// Create timeout context based on flow expiration
|
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
|
|
tokenChan := make(chan *oauth2.Token, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
|
if err != nil {
|
|
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
|
}
|
|
|
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
|
defer func() {
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
|
log.Errorf("failed to close the server: %v", err)
|
|
}
|
|
}()
|
|
|
|
go p.startServer(server, tokenChan, errChan)
|
|
|
|
select {
|
|
case <-waitCtx.Done():
|
|
return TokenInfo{}, waitCtx.Err()
|
|
case token := <-tokenChan:
|
|
return p.parseOAuthToken(token)
|
|
case err := <-errChan:
|
|
return TokenInfo{}, err
|
|
}
|
|
}
|
|
|
|
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
|
cert := p.providerConfig.ClientCertPair
|
|
if cert != nil {
|
|
tr := &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
Certificates: []tls.Certificate{*cert},
|
|
},
|
|
}
|
|
sslClient := &http.Client{Transport: tr}
|
|
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
|
|
req = req.WithContext(ctx)
|
|
}
|
|
|
|
token, err := p.handleRequest(req)
|
|
if err != nil {
|
|
renderPKCEFlowTmpl(w, err)
|
|
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
|
return
|
|
}
|
|
|
|
renderPKCEFlowTmpl(w, nil)
|
|
tokenChan <- token
|
|
})
|
|
|
|
server.Handler = mux
|
|
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
errChan <- err
|
|
}
|
|
}
|
|
|
|
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
|
query := req.URL.Query()
|
|
|
|
if authError := query.Get(queryError); authError != "" {
|
|
authErrorDesc := query.Get(queryErrorDesc)
|
|
if authErrorDesc != "" {
|
|
return nil, fmt.Errorf("authentication failed: %s", authErrorDesc)
|
|
}
|
|
return nil, fmt.Errorf("authentication failed: %s", authError)
|
|
}
|
|
|
|
// Prevent timing attacks on the state
|
|
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
|
return nil, fmt.Errorf("authentication failed: Invalid state")
|
|
}
|
|
|
|
code := query.Get(queryCode)
|
|
if code == "" {
|
|
return nil, fmt.Errorf("authentication failed: missing code")
|
|
}
|
|
|
|
return p.oAuthConfig.Exchange(
|
|
req.Context(),
|
|
code,
|
|
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
|
)
|
|
}
|
|
|
|
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
|
tokenInfo := TokenInfo{
|
|
AccessToken: token.AccessToken,
|
|
RefreshToken: token.RefreshToken,
|
|
TokenType: token.TokenType,
|
|
ExpiresIn: token.Expiry.Second(),
|
|
UseIDToken: p.providerConfig.UseIDToken,
|
|
}
|
|
if idToken, ok := token.Extra("id_token").(string); ok {
|
|
tokenInfo.IDToken = idToken
|
|
}
|
|
|
|
// if a provider doesn't support an audience, use the Client ID for token verification
|
|
audience := p.providerConfig.Audience
|
|
if audience == "" {
|
|
audience = p.providerConfig.ClientID
|
|
}
|
|
|
|
if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil {
|
|
return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err)
|
|
}
|
|
|
|
email, err := parseEmailFromIDToken(tokenInfo.IDToken)
|
|
if err != nil {
|
|
log.Warnf("failed to parse email from ID token: %v", err)
|
|
} else {
|
|
tokenInfo.Email = email
|
|
}
|
|
|
|
return tokenInfo, nil
|
|
}
|
|
|
|
// parseEmailFromIDToken extracts a user identifier from the payload of a
// dot-separated token.
//
// The second segment is decoded as unpadded base64url JSON. The "email" claim
// is returned when it is a string; otherwise the "name" claim is used. An
// error is returned when the token has fewer than two segments, the payload
// cannot be decoded or parsed, or neither claim is a string. The token
// signature is not verified here.
func parseEmailFromIDToken(token string) (string, error) {
	segments := strings.Split(token, ".")
	if len(segments) < 2 {
		return "", fmt.Errorf("invalid token format")
	}

	payload, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return "", fmt.Errorf("failed to decode payload: %w", err)
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return "", fmt.Errorf("json unmarshal error: %w", err)
	}

	// Claims are tried in order of preference.
	for _, claim := range []string{"email", "name"} {
		if value, ok := claims[claim].(string); ok {
			return value, nil
		}
	}

	return "", fmt.Errorf("email or name field not found in token payload")
}
|
|
|
|
func createCodeChallenge(codeVerifier string) string {
|
|
sha2 := sha256.Sum256([]byte(codeVerifier))
|
|
return base64.RawURLEncoding.EncodeToString(sha2[:])
|
|
}
|
|
|
|
// isRedirectURLPortUsed checks if the port used in the redirect URL is in use or excluded on Windows.
|
|
func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRange) bool {
|
|
parsedURL, err := url.Parse(redirectURL)
|
|
if err != nil {
|
|
log.Errorf("failed to parse redirect URL: %v", err)
|
|
return true
|
|
}
|
|
|
|
port := parsedURL.Port()
|
|
|
|
if isPortInExcludedRange(port, excludedRanges) {
|
|
log.Warnf("port %s is in Windows excluded port range, skipping", port)
|
|
return true
|
|
}
|
|
|
|
addr := fmt.Sprintf(":%s", port)
|
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
log.Errorf("error while closing the connection: %v", err)
|
|
}
|
|
}()
|
|
|
|
return true
|
|
}
|
|
|
|
// excludedPortRange describes a range of excluded ports; both bounds are
// treated as part of the range.
type excludedPortRange struct {
	start, end int
}
|
|
|
|
// isPortInExcludedRange checks if the given port is in any of the excluded ranges.
|
|
func isPortInExcludedRange(port string, excludedRanges []excludedPortRange) bool {
|
|
if len(excludedRanges) == 0 {
|
|
return false
|
|
}
|
|
|
|
portNum, err := strconv.Atoi(port)
|
|
if err != nil {
|
|
log.Debugf("invalid port number %s: %v", port, err)
|
|
return false
|
|
}
|
|
|
|
for _, r := range excludedRanges {
|
|
if portNum >= r.start && portNum <= r.end {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func renderPKCEFlowTmpl(w http.ResponseWriter, authError error) {
|
|
tmpl, err := template.New("pkce-auth-flow").Parse(templates.PKCEAuthMsgTmpl)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
data := make(map[string]string)
|
|
if authError != nil {
|
|
data["Error"] = authError.Error()
|
|
}
|
|
|
|
if err := tmpl.Execute(w, data); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|