From 4bbedb5193f13f1e033acd4657a9c4f41d048bde Mon Sep 17 00:00:00 2001 From: Foosec <31885466+Foosec@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:07:44 +0200 Subject: [PATCH] [client] Add mTLS support for SSO login (#2188) * Add mTLS support for SSO login * Refactor variable to follow Go naming conventions --------- Co-authored-by: bcmmbaga --- client/android/login.go | 2 +- client/internal/auth/oauth.go | 2 +- client/internal/auth/pkce_flow.go | 13 +++++++++++++ client/internal/config.go | 30 ++++++++++++++++++++++++++++++ client/internal/pkce_auth.go | 6 +++++- client/ios/NetBirdSDK/login.go | 2 +- 6 files changed, 51 insertions(+), 4 deletions(-) diff --git a/client/android/login.go b/client/android/login.go index 3840c75c1..bb61edfa8 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) { supportsSSO := true err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) s, ok := gstatus.FromError(err) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index c9f10ca86..001609f26 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -86,7 +86,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 32f5383d3..71ff6de41 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "crypto/subtle" + "crypto/tls" "encoding/base64" "errors" "fmt" @@ -143,6 +144,18 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + cert := p.providerConfig.ClientCertPair + if cert != nil { + tr := &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, + } + sslClient := &http.Client{Transport: tr} + ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient) + req = req.WithContext(ctx) + } + token, err := p.handleRequest(req) if err != nil { renderPKCEFlowTmpl(w, err) diff --git a/client/internal/config.go b/client/internal/config.go index 461dcdd96..725703c43 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/tls" "fmt" "net/url" "os" @@ -57,6 +58,8 @@ type ConfigInput struct { DisableAutoConnect *bool ExtraIFaceBlackList []string DNSRouteInterval *time.Duration + ClientCertPath string + ClientCertKeyPath string } // Config Configuration type @@ -102,6 +105,13 @@ type Config struct { // DNSRouteInterval is the interval in which the DNS routes are updated DNSRouteInterval time.Duration + //Path to a certificate used for mTLS authentication + ClientCertPath string + + //Path to corresponding private key of ClientCertPath + ClientCertKeyPath string + + ClientCertKeyPair *tls.Certificate `json:"-"` } // ReadConfig read config file and return with Config. If it is not exists create a new with default values @@ -385,6 +395,26 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { } + if input.ClientCertKeyPath != "" { + config.ClientCertKeyPath = input.ClientCertKeyPath + updated = true + } + + if input.ClientCertPath != "" { + config.ClientCertPath = input.ClientCertPath + updated = true + } + + if config.ClientCertPath != "" && config.ClientCertKeyPath != "" { + cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath) + if err != nil { + log.Error("Failed to load mTLS cert/key pair: ", err) + } else { + config.ClientCertKeyPair = &cert + log.Info("Loaded client mTLS cert/key pair") + } + } + return updated, nil } diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index a35dacc77..6f714889f 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -2,6 +2,7 @@ package internal import ( "context" + "crypto/tls" "fmt" "net/url" @@ -36,10 +37,12 @@ type PKCEAuthProviderConfig struct { RedirectURLs []string // UseIDToken indicates if the id token should be used for authentication UseIDToken bool + //ClientCertPair is used for mTLS authentication to the IDP + ClientCertPair *tls.Certificate } // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it -func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) { +func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) { // validate our peer's Wireguard PRIVATE key myPrivateKey, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -93,6 +96,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(), RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(), UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(), + ClientCertPair: clientCert, }, } diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 3572aa310..ff637edd4 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { err := a.withBackOff(a.ctx, func() (err error) { _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) + _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { supportsSSO = false err = nil