mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Consolidate authentication logic (#5010)
* Consolidate authentication logic - Moving auth functions from client/internal to client/internal/auth package - Creating unified auth.Auth client with NewAuth() constructor - Replacing direct auth function calls with auth client methods - Refactoring device flow and PKCE flow implementations - Updating iOS/Android/server code to use new auth client API * Refactor PKCE auth and login methods - Remove unnecessary internal package reference in PKCE flow test - Adjust context assignment placement in iOS and Android login methods
This commit is contained in:
@@ -3,15 +3,7 @@ package android
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
@@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
|
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err, _ = authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
go urlOpener.OnLoginSuccess()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
|
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
|||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import (
|
|||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||||
|
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin := false
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
err, isAuthError := authClient.Login(ctx, "", "")
|
||||||
err := internal.Login(ctx, config, "", "")
|
if isAuthError {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
needsLogin = true
|
||||||
needsLogin = true
|
} else if err != nil {
|
||||||
return nil
|
return fmt.Errorf("login check failed: %v", err)
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastError error
|
err, _ = authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
lastError = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
if lastError != nil {
|
|
||||||
return fmt.Errorf("login failed: %v", lastError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
|
||||||
defer c()
|
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -176,7 +177,13 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
499
client/internal/auth/auth.go
Normal file
499
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth manages authentication operations with the management server
|
||||||
|
// It maintains a long-lived connection and automatically handles reconnection with backoff
|
||||||
|
type Auth struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
client *mgm.GrpcClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
privateKey wgtypes.Key
|
||||||
|
mgmURL *url.URL
|
||||||
|
mgmTLSEnabled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth creates a new Auth instance that manages authentication flows
|
||||||
|
// It establishes a connection to the management server that will be reused for all operations
|
||||||
|
// The connection is automatically recreated with backoff if it becomes disconnected
|
||||||
|
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||||
|
// Validate WireGuard private key
|
||||||
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine TLS setting based on URL scheme
|
||||||
|
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||||
|
|
||||||
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
|
return &Auth{
|
||||||
|
client: mgmClient,
|
||||||
|
config: config,
|
||||||
|
privateKey: myPrivateKey,
|
||||||
|
mgmURL: mgmURL,
|
||||||
|
mgmTLSEnabled: mgmTLSEnabled,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the management client connection
|
||||||
|
func (a *Auth) Close() error {
|
||||||
|
a.mutex.Lock()
|
||||||
|
defer a.mutex.Unlock()
|
||||||
|
|
||||||
|
if a.client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||||
|
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||||
|
// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers.
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||||
|
var supportsSSO bool
|
||||||
|
|
||||||
|
err := a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
// Try PKCE flow first
|
||||||
|
_, err := a.getPKCEFlow(client)
|
||||||
|
if err == nil {
|
||||||
|
supportsSSO = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if PKCE is not supported
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
// PKCE not supported, try Device flow
|
||||||
|
_, err = a.getDeviceFlow(client)
|
||||||
|
if err == nil {
|
||||||
|
supportsSSO = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Device flow is also not supported
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
// Neither PKCE nor Device flow is supported
|
||||||
|
supportsSSO = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Device flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return supportsSSO, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection
|
||||||
|
// This avoids creating a new connection to the management server
|
||||||
|
func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) {
|
||||||
|
var flow OAuthFlow
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
if forceDeviceAuth {
|
||||||
|
flow, err = a.getDeviceFlow(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try PKCE flow first
|
||||||
|
flow, err = a.getPKCEFlow(client)
|
||||||
|
if err != nil {
|
||||||
|
// If PKCE not supported, try Device flow
|
||||||
|
if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
||||||
|
flow, err = a.getDeviceFlow(client)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return flow, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var needsLogin bool
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
|
if isLoginNeeded(err) {
|
||||||
|
needsLogin = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
needsLogin = false
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
return needsLogin, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login attempts to log in or register the client with the management server
|
||||||
|
// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries.
|
||||||
|
// Automatically retries with backoff and reconnection on connection errors.
|
||||||
|
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return err, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var isAuthError bool
|
||||||
|
|
||||||
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
|
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
|
if serverKey != nil && isRegistrationNeeded(err) {
|
||||||
|
log.Debugf("peer registration required")
|
||||||
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
|
if err != nil {
|
||||||
|
isAuthError = isPermissionDenied(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
isAuthError = isPermissionDenied(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
isAuthError = false
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return err, isAuthError
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &PKCEAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
ClientCertPair: a.config.ClientCertKeyPair,
|
||||||
|
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||||
|
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePKCEConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||||
|
if err != nil {
|
||||||
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve device flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &DeviceAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
Domain: protoConfig.Domain,
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep compatibility with older management versions
|
||||||
|
if config.Scope == "" {
|
||||||
|
config.Scope = "openid"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateDeviceAuthConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||||
|
serverKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
a.setSystemInfoFlags(sysInfo)
|
||||||
|
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
|
return serverKey, loginResp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||||
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
|
if err != nil && jwtToken == "" {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
a.setSystemInfoFlags(info)
|
||||||
|
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed registering peer %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("peer has been successfully registered on Management Service")
|
||||||
|
|
||||||
|
return loginResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSystemInfoFlags sets all configuration flags on the provided system info
|
||||||
|
func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||||
|
info.SetFlags(
|
||||||
|
a.config.RosenpassEnabled,
|
||||||
|
a.config.RosenpassPermissive,
|
||||||
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.DisableClientRoutes,
|
||||||
|
a.config.DisableServerRoutes,
|
||||||
|
a.config.DisableDNS,
|
||||||
|
a.config.DisableFirewall,
|
||||||
|
a.config.BlockLANAccess,
|
||||||
|
a.config.BlockInbound,
|
||||||
|
a.config.LazyConnectionEnabled,
|
||||||
|
a.config.EnableSSHRoot,
|
||||||
|
a.config.EnableSSHSFTP,
|
||||||
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
|
a.config.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconnect closes the current connection and creates a new one
|
||||||
|
// It checks if the brokenClient is still the current client before reconnecting
|
||||||
|
// to avoid multiple threads reconnecting unnecessarily
|
||||||
|
func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error {
|
||||||
|
a.mutex.Lock()
|
||||||
|
defer a.mutex.Unlock()
|
||||||
|
|
||||||
|
// Double-check: if client has already been replaced by another thread, skip reconnection
|
||||||
|
if a.client != brokenClient {
|
||||||
|
log.Debugf("client already reconnected by another thread, skipping")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new connection FIRST, before closing the old one
|
||||||
|
// This ensures a.client is never nil, preventing panics in other threads
|
||||||
|
log.Debugf("reconnecting to Management Service %s", a.mgmURL.String())
|
||||||
|
mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err)
|
||||||
|
// Keep the old client if reconnection fails
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close old connection AFTER new one is successfully created
|
||||||
|
oldClient := a.client
|
||||||
|
a.client = mgmClient
|
||||||
|
|
||||||
|
if oldClient != nil {
|
||||||
|
if err := oldClient.Close(); err != nil {
|
||||||
|
log.Debugf("error closing old connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||||
|
func isConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// These error codes indicate connection issues
|
||||||
|
return s.Code() == codes.Unavailable ||
|
||||||
|
s.Code() == codes.DeadlineExceeded ||
|
||||||
|
s.Code() == codes.Canceled ||
|
||||||
|
s.Code() == codes.Internal
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRetry wraps an operation with exponential backoff retry logic
|
||||||
|
// It automatically reconnects on connection errors
|
||||||
|
func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error {
|
||||||
|
backoffSettings := &backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 500 * time.Millisecond,
|
||||||
|
RandomizationFactor: 0.5,
|
||||||
|
Multiplier: 1.5,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 2 * time.Minute,
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}
|
||||||
|
backoffSettings.Reset()
|
||||||
|
|
||||||
|
return backoff.RetryNotify(
|
||||||
|
func() error {
|
||||||
|
// Capture the client BEFORE the operation to ensure we track the correct client
|
||||||
|
a.mutex.RLock()
|
||||||
|
currentClient := a.client
|
||||||
|
a.mutex.RUnlock()
|
||||||
|
|
||||||
|
if currentClient == nil {
|
||||||
|
return status.Errorf(codes.Unavailable, "client is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute operation with the captured client
|
||||||
|
err := operation(currentClient)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a connection error, attempt reconnection using the client that was actually used
|
||||||
|
if isConnectionError(err) {
|
||||||
|
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||||
|
|
||||||
|
if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil {
|
||||||
|
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||||
|
return reconnectErr
|
||||||
|
}
|
||||||
|
// Return the original error to trigger retry with the new connection
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For authentication errors, don't retry
|
||||||
|
if isAuthenticationError(err) {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
backoff.WithContext(backoffSettings, ctx),
|
||||||
|
func(err error, duration time.Duration) {
|
||||||
|
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAuthenticationError checks if the error is an authentication-related error that should not be retried.
|
||||||
|
// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help.
|
||||||
|
func isAuthenticationError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPermissionDenied checks if the error is a PermissionDenied error.
|
||||||
|
// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access).
|
||||||
|
func isPermissionDenied(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := status.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLoginNeeded(err error) bool {
|
||||||
|
return isAuthenticationError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isRegistrationNeeded(err error) bool {
|
||||||
|
return isPermissionDenied(err)
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,12 +25,56 @@ const (
|
|||||||
|
|
||||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||||
|
|
||||||
|
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
|
type DeviceAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Domain An IDP API domain
|
||||||
|
// Deprecated. Use OIDCConfigEndpoint instead
|
||||||
|
Domain string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||||
|
DeviceAuthEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||||
|
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.DeviceAuthEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||||
// for the Device Authorization Flow.
|
// for the Device Authorization Flow.
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
providerConfig internal.DeviceAuthProviderConfig
|
providerConfig DeviceAuthProviderConfig
|
||||||
|
HTTPClient HTTPClient
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
|||||||
return d.providerConfig.ClientID
|
return d.providerConfig.ClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the device authorization flow
|
||||||
|
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
d.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||||
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
|
||||||
tokenResponse, err := d.requestToken(info)
|
tokenResponse, err := d.requestToken(info)
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
err: testCase.inputReqError,
|
err: testCase.inputReqError,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := &DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: expectedAudience,
|
||||||
Audience: expectedAudience,
|
ClientID: expectedClientID,
|
||||||
ClientID: expectedClientID,
|
Scope: expectedScope,
|
||||||
Scope: expectedScope,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
countResBody: testCase.inputCountResBody,
|
countResBody: testCase.inputCountResBody,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: testCase.inputAudience,
|
||||||
Audience: testCase.inputAudience,
|
ClientID: clientID,
|
||||||
ClientID: clientID,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
Scope: "openid",
|
||||||
Scope: "openid",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
|||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
pkceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
return pkceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch s, ok := gstatus.FromError(err); {
|
switch s, ok := gstatus.FromError(err); {
|
||||||
case ok && s.Code() == codes.NotFound:
|
case ok && s.Code() == codes.NotFound:
|
||||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
deviceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
return deviceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
@@ -35,17 +34,67 @@ const (
|
|||||||
defaultPKCETimeoutSeconds = 300
|
defaultPKCETimeoutSeconds = 300
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||||
|
type PKCEAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// RedirectURL handles authorization code from IDP manager
|
||||||
|
RedirectURLs []string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// ClientCertPair is used for mTLS authentication to the IDP
|
||||||
|
ClientCertPair *tls.Certificate
|
||||||
|
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||||
|
DisablePromptLogin bool
|
||||||
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
|
LoginFlag common.LoginFlag
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePKCEConfig validates PKCE provider configuration
|
||||||
|
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.AuthorizationEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||||
|
}
|
||||||
|
if config.RedirectURLs == nil {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||||
// the Authorization Code Flow with PKCE.
|
// the Authorization Code Flow with PKCE.
|
||||||
type PKCEAuthorizationFlow struct {
|
type PKCEAuthorizationFlow struct {
|
||||||
providerConfig internal.PKCEAuthProviderConfig
|
providerConfig PKCEAuthProviderConfig
|
||||||
state string
|
state string
|
||||||
codeVerifier string
|
codeVerifier string
|
||||||
oAuthConfig *oauth2.Config
|
oAuthConfig *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
excludedRanges := getSystemExcludedPortRanges()
|
excludedRanges := getSystemExcludedPortRanges()
|
||||||
@@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||||
|
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
p.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
|
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tokenChan := make(chan *oauth2.Token, 1)
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
@@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
|
|
||||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
defer func() {
|
defer func() {
|
||||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
@@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
go p.startServer(server, tokenChan, errChan)
|
go p.startServer(server, tokenChan, errChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case token := <-tokenChan:
|
case token := <-tokenChan:
|
||||||
return p.parseOAuthToken(token)
|
return p.parseOAuthToken(token)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseExcludedPortRanges(t *testing.T) {
|
func TestParseExcludedPortRanges(t *testing.T) {
|
||||||
@@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) {
|
|||||||
|
|
||||||
availablePort := 65432
|
availablePort := 65432
|
||||||
|
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -1,136 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
|
||||||
type DeviceAuthorizationFlow struct {
|
|
||||||
Provider string
|
|
||||||
ProviderConfig DeviceAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
|
||||||
type DeviceAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Domain An IDP API domain
|
|
||||||
// Deprecated. Use OIDCConfigEndpoint instead
|
|
||||||
Domain string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
|
||||||
DeviceAuthEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
|
||||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve device flow: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
|
||||||
|
|
||||||
ProviderConfig: DeviceAuthProviderConfig{
|
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
|
||||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
|
||||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep compatibility with older management versions
|
|
||||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
|
||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceAuthorizationFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.Audience == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
|
||||||
}
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.DeviceAuthEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsLoginRequired check that the server is support SSO or not
|
|
||||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
|
||||||
mgmURL := config.ManagementURL
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if isLoginNeeded(err) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login or register the client
|
|
||||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
|
||||||
log.Debugf("peer registration required")
|
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return mgmClient, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
|
||||||
sysInfo.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
|
||||||
return serverKey, loginResp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
|
||||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
|
||||||
if err != nil && jwtToken == "" {
|
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("sending peer registration request to Management Service")
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
info.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed registering peer %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("peer has been successfully registered on Management Service")
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isLoginNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRegistrationNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
ProviderConfig PKCEAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
|
||||||
type PKCEAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
|
||||||
AuthorizationEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// RedirectURL handles authorization code from IDP manager
|
|
||||||
RedirectURLs []string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// ClientCertPair is used for mTLS authentication to the IDP
|
|
||||||
ClientCertPair *tls.Certificate
|
|
||||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
|
||||||
DisablePromptLogin bool
|
|
||||||
// LoginFlag is used to configure the PKCE flow login behavior
|
|
||||||
LoginFlag common.LoginFlag
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
|
||||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authFlow := PKCEAuthorizationFlow{
|
|
||||||
ProviderConfig: PKCEAuthProviderConfig{
|
|
||||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
|
||||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
|
||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
ClientCertPair: clientCert,
|
|
||||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
|
||||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return authFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.AuthorizationEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
|
||||||
}
|
|
||||||
if config.RedirectURLs == nil {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
needsLogin, err := internal.IsLoginRequired(ctx, cfg)
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("IsLoginRequired: failed to create auth client: %v", err)
|
||||||
|
return true // Assume login is required if we can't create auth client
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("IsLoginRequired: check failed: %v", err)
|
log.Errorf("IsLoginRequired: check failed: %v", err)
|
||||||
// If the check fails, assume login is required to be safe
|
// If the check fails, assume login is required to be safe
|
||||||
@@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string {
|
|||||||
|
|
||||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||||
go func() {
|
go func() {
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
log.Errorf("LoginForMobile: WaitToken failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwtToken := tokenInfo.GetTokenToUse()
|
jwtToken := tokenInfo.GetTokenToUse()
|
||||||
if err := internal.Login(ctx, cfg, "", jwtToken); err != nil {
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("LoginForMobile: failed to create auth client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
if err, _ := authClient.Login(ctx, "", jwtToken); err != nil {
|
||||||
log.Errorf("LoginForMobile: Login failed: %v", err)
|
log.Errorf("LoginForMobile: Login failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,13 +7,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
// which are blocked by the tvOS sandbox in App Group containers
|
// which are blocked by the tvOS sandbox in App Group containers
|
||||||
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
err, _ = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
// Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename)
|
||||||
@@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string
|
|||||||
// LoginSync performs a synchronous login check without UI interaction
|
// LoginSync performs a synchronous login check without UI interaction
|
||||||
// Used for background VPN connection where user should already be authenticated
|
// Used for background VPN connection where user should already be authenticated
|
||||||
func (a *Auth) LoginSync() error {
|
func (a *Auth) LoginSync() error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error {
|
|||||||
return fmt.Errorf("not authenticated")
|
return fmt.Errorf("not authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err, isAuthError := authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// PermissionDenied means registration is required or peer is blocked
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return fmt.Errorf("authentication error: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("login failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error {
|
||||||
var needsLogin bool
|
|
||||||
|
|
||||||
// Create context with device name if provided
|
// Create context with device name if provided
|
||||||
ctx := a.ctx
|
ctx := a.ctx
|
||||||
if deviceName != "" {
|
if deviceName != "" {
|
||||||
@@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
|||||||
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(ctx, func() (err error) {
|
|
||||||
needsLogin, err = internal.IsLoginRequired(ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
// check if we need to generate JWT token
|
||||||
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
if needsLogin {
|
if needsLogin {
|
||||||
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth)
|
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||||
}
|
}
|
||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(ctx, func() error {
|
err, isAuthError := authClient.Login(ctx, "", jwtToken)
|
||||||
err := internal.Login(ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// PermissionDenied means registration is required or peer is blocked
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
|
// PermissionDenied means registration is required or peer is blocked
|
||||||
|
return fmt.Errorf("authentication error: %v", err)
|
||||||
|
}
|
||||||
return fmt.Errorf("login failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin
|
|||||||
|
|
||||||
const authInfoRequestTimeout = 30 * time.Second
|
const authInfoRequestTimeout = 30 * time.Second
|
||||||
|
|
||||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "")
|
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
// Use a bounded timeout for the auth info request to prevent indefinite hangs
|
||||||
@@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool)
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetConfigJSON returns the current config as a JSON string.
|
// GetConfigJSON returns the current config as a JSON string.
|
||||||
// This can be used by the caller to persist the config via alternative storage
|
// This can be used by the caller to persist the config via alternative storage
|
||||||
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
// mechanisms (e.g., UserDefaults on tvOS where file writes are blocked).
|
||||||
|
|||||||
@@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
|||||||
|
|
||||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||||
var status internal.StatusType
|
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
log.Errorf("failed to create auth client: %v", err)
|
||||||
|
return internal.StatusLoginFailed, err
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
var status internal.StatusType
|
||||||
|
err, isAuthError := authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
if isAuthError {
|
||||||
log.Warnf("failed login: %v", err)
|
log.Warnf("failed login: %v", err)
|
||||||
status = internal.StatusNeedsLogin
|
status = internal.StatusNeedsLogin
|
||||||
} else {
|
} else {
|
||||||
@@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
s.oauthAuthFlow.waitCancel()
|
s.oauthAuthFlow.waitCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
waitCTX, cancel := context.WithCancel(ctx)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
|
|||||||
Reference in New Issue
Block a user