diff --git a/client/android/login.go b/client/android/login.go index 4d4c7a650..a9422cdbf 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -3,15 +3,7 @@ package android import ( "context" "fmt" - "time" - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/cmd" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" @@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { } func (a *Auth) saveConfigIfSSOSupported() (bool, error) { - supportsSSO := true - err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - s, ok := gstatus.FromError(err) - if !ok { - return err - } - if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { - supportsSSO = false - err = nil - } + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return false, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() - return err - } - - return err - }) + supportsSSO, err := authClient.IsSSOSupported(a.ctx) + if err != nil { + return false, fmt.Errorf("failed to check SSO support: %v", err) + } if !supportsSSO { return false, nil } - if err != nil { - return false, fmt.Errorf("backoff cycle failed: %v", err) - } - err = profilemanager.WriteOutConfig(a.cfgPath, a.config) return true, err } @@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK } func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) - - err := a.withBackOff(a.ctx, func() error { - backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") - if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { - // we got an answer from management, exit backoff earlier - return backoff.Permanent(backoffErr) - } - return backoffErr - }) + err, _ = authClient.Login(ctxWithValues, setupKey, "") if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } return profilemanager.WriteOutConfig(a.cfgPath, a.config) @@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT } func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { - var needsLogin bool + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() // check if we need to generate JWT token - err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) - return - }) + needsLogin, err := authClient.IsLoginRequired(a.ctx) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" if needsLogin { - tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV) + tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } jwtToken = tokenInfo.GetTokenToUse() } - err = a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, "", jwtToken) - - if err == nil { - go urlOpener.OnLoginSuccess() - } - - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - return nil - } - return err - }) + err, _ = authClient.Login(a.ctx, "", jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } + go urlOpener.OnLoginSuccess() + return nil } -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "") +func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) { + oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get OAuth flow: %v", err) } flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) @@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) - defer cancel() - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo) if err != nil { return nil, fmt.Errorf("waiting for browser login failed: %v", err) } return &tokenInfo, nil } - -func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { - return backoff.RetryNotify( - bf, - backoff.WithContext(cmd.CLIBackOffSettings, ctx), - func(err error, duration time.Duration) { - log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) - }) -} diff --git a/client/cmd/login.go b/client/cmd/login.go index 57c010571..64b45e557 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -7,7 +7,6 @@ import ( "os/user" "runtime" "strings" - "time" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" @@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo } func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { + authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + needsLogin := false - err := WithBackOff(func() error { - err := internal.Login(ctx, config, "", "") - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - needsLogin = true - return nil - } - return err - }) - if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + err, isAuthError := authClient.Login(ctx, "", "") + if isAuthError { + needsLogin = true + } else if err != nil { + return fmt.Errorf("login check failed: %v", err) } jwtToken := "" @@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken = tokenInfo.GetTokenToUse() } - var lastError error - - err = WithBackOff(func() error { - err := internal.Login(ctx, config, setupKey, jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - lastError = err - return nil - } - return err - }) - - if lastError != nil { - return fmt.Errorf("login failed: %v", lastError) - } - + err, _ = authClient.Login(ctx, setupKey, jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } return nil @@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout) - defer c() - - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) if err != nil { return nil, fmt.Errorf("waiting for browser login failed: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 8bbbef0f2..e266aae28 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" sshcommon "github.com/netbirdio/netbird/client/ssh" @@ -176,7 +177,13 @@ func (c *Client) Start(startCtx context.Context) error { // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { + authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) + if err != nil { + return fmt.Errorf("create auth client: %w", err) + } + defer authClient.Close() + + if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go new file mode 100644 index 000000000..44e98bede --- /dev/null +++ b/client/internal/auth/auth.go @@ -0,0 +1,499 @@ +package auth + +import ( + "context" + "net/url" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/system" + mgm "github.com/netbirdio/netbird/shared/management/client" + "github.com/netbirdio/netbird/shared/management/client/common" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" +) + +// Auth manages authentication operations with the management server +// It maintains a long-lived connection and automatically handles reconnection with backoff +type Auth struct { + mutex sync.RWMutex + client *mgm.GrpcClient + config *profilemanager.Config + privateKey wgtypes.Key + mgmURL *url.URL + mgmTLSEnabled bool +} + +// NewAuth creates a new Auth instance that manages authentication flows +// It establishes a connection to the management server that will be reused for all operations +// The connection is automatically recreated with backoff if it becomes disconnected +func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) { + // Validate WireGuard private key + myPrivateKey, err := wgtypes.ParseKey(privateKey) + if err != nil { + return nil, err + } + + // Determine TLS setting based on URL scheme + mgmTLSEnabled := mgmURL.Scheme == "https" + + log.Debugf("connecting to Management Service %s", mgmURL.String()) + mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) + if err != nil { + log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err) + return nil, err + } + + log.Debugf("connected to the Management service %s", mgmURL.String()) + + return &Auth{ + client: mgmClient, + config: config, + privateKey: myPrivateKey, + mgmURL: mgmURL, + mgmTLSEnabled: mgmTLSEnabled, + }, nil +} + +// Close closes the management client connection +func (a *Auth) Close() error { + a.mutex.Lock() + defer a.mutex.Unlock() + + if a.client == nil { + return nil + } + return a.client.Close() +} + +// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations. +// Returns true if either PKCE or Device authorization flow is supported, false otherwise. +// This function encapsulates the SSO detection logic to avoid exposing gRPC error codes to upper layers. +// Automatically retries with backoff and reconnection on connection errors. +func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) { + var supportsSSO bool + + err := a.withRetry(ctx, func(client *mgm.GrpcClient) error { + // Try PKCE flow first + _, err := a.getPKCEFlow(client) + if err == nil { + supportsSSO = true + return nil + } + + // Check if PKCE is not supported + if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + // PKCE not supported, try Device flow + _, err = a.getDeviceFlow(client) + if err == nil { + supportsSSO = true + return nil + } + + // Check if Device flow is also not supported + if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + // Neither PKCE nor Device flow is supported + supportsSSO = false + return nil + } + + // Device flow check returned an error other than NotFound/Unimplemented + return err + } + + // PKCE flow check returned an error other than NotFound/Unimplemented + return err + }) + + return supportsSSO, err +} + +// GetOAuthFlow returns an OAuth flow (PKCE or Device) using the existing management connection +// This avoids creating a new connection to the management server +func (a *Auth) GetOAuthFlow(ctx context.Context, forceDeviceAuth bool) (OAuthFlow, error) { + var flow OAuthFlow + var err error + + err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { + if forceDeviceAuth { + flow, err = a.getDeviceFlow(client) + return err + } + + // Try PKCE flow first + flow, err = a.getPKCEFlow(client) + if err != nil { + // If PKCE not supported, try Device flow + if s, ok := status.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { + flow, err = a.getDeviceFlow(client) + return err + } + return err + } + return nil + }) + + return flow, err +} + +// IsLoginRequired checks if login is required by attempting to authenticate with the server +// Automatically retries with backoff and reconnection on connection errors. +func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) { + pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey)) + if err != nil { + return false, err + } + + var needsLogin bool + + err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { + _, _, err := a.doMgmLogin(client, ctx, pubSSHKey) + if isLoginNeeded(err) { + needsLogin = true + return nil + } + needsLogin = false + return err + }) + + return needsLogin, err +} + +// Login attempts to log in or register the client with the management server +// Returns error and a boolean indicating if it's an authentication error (permission denied) that should stop retries. +// Automatically retries with backoff and reconnection on connection errors. +func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (error, bool) { + pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey)) + if err != nil { + return err, false + } + + var isAuthError bool + + err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { + serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey) + if serverKey != nil && isRegistrationNeeded(err) { + log.Debugf("peer registration required") + _, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey) + if err != nil { + isAuthError = isPermissionDenied(err) + return err + } + } else if err != nil { + isAuthError = isPermissionDenied(err) + return err + } + + isAuthError = false + return nil + }) + + return err, isAuthError +} + +// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance +func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) { + serverKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey) + if err != nil { + if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + log.Warnf("server couldn't find pkce flow, contact admin: %v", err) + return nil, err + } + log.Errorf("failed to retrieve pkce flow: %v", err) + return nil, err + } + + protoConfig := protoFlow.GetProviderConfig() + config := &PKCEAuthProviderConfig{ + Audience: protoConfig.GetAudience(), + ClientID: protoConfig.GetClientID(), + ClientSecret: protoConfig.GetClientSecret(), + TokenEndpoint: protoConfig.GetTokenEndpoint(), + AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(), + Scope: protoConfig.GetScope(), + RedirectURLs: protoConfig.GetRedirectURLs(), + UseIDToken: protoConfig.GetUseIDToken(), + ClientCertPair: a.config.ClientCertKeyPair, + DisablePromptLogin: protoConfig.GetDisablePromptLogin(), + LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()), + } + + if err := validatePKCEConfig(config); err != nil { + return nil, err + } + + flow, err := NewPKCEAuthorizationFlow(*config) + if err != nil { + return nil, err + } + + return flow, nil +} + +// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance +func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) { + serverKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey) + if err != nil { + if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { + log.Warnf("server couldn't find device flow, contact admin: %v", err) + return nil, err + } + log.Errorf("failed to retrieve device flow: %v", err) + return nil, err + } + + protoConfig := protoFlow.GetProviderConfig() + config := &DeviceAuthProviderConfig{ + Audience: protoConfig.GetAudience(), + ClientID: protoConfig.GetClientID(), + ClientSecret: protoConfig.GetClientSecret(), + Domain: protoConfig.Domain, + TokenEndpoint: protoConfig.GetTokenEndpoint(), + DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(), + Scope: protoConfig.GetScope(), + UseIDToken: protoConfig.GetUseIDToken(), + } + + // Keep compatibility with older management versions + if config.Scope == "" { + config.Scope = "openid" + } + + if err := validateDeviceAuthConfig(config); err != nil { + return nil, err + } + + flow, err := NewDeviceAuthorizationFlow(*config) + if err != nil { + return nil, err + } + + return flow, nil +} + +// doMgmLogin performs the actual login operation with the management service +func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) { + serverKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, nil, err + } + + sysInfo := system.GetInfo(ctx) + a.setSystemInfoFlags(sysInfo) + loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels) + return serverKey, loginResp, err +} + +// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. +// Otherwise tries to register with the provided setupKey via command line. +func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { + serverPublicKey, err := client.GetServerPublicKey() + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, err + } + + validSetupKey, err := uuid.Parse(setupKey) + if err != nil && jwtToken == "" { + return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) + } + + log.Debugf("sending peer registration request to Management Service") + info := system.GetInfo(ctx) + a.setSystemInfoFlags(info) + loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) + if err != nil { + log.Errorf("failed registering peer %v", err) + return nil, err + } + + log.Infof("peer has been successfully registered on Management Service") + + return loginResp, nil +} + +// setSystemInfoFlags sets all configuration flags on the provided system info +func (a *Auth) setSystemInfoFlags(info *system.Info) { + info.SetFlags( + a.config.RosenpassEnabled, + a.config.RosenpassPermissive, + a.config.ServerSSHAllowed, + a.config.DisableClientRoutes, + a.config.DisableServerRoutes, + a.config.DisableDNS, + a.config.DisableFirewall, + a.config.BlockLANAccess, + a.config.BlockInbound, + a.config.LazyConnectionEnabled, + a.config.EnableSSHRoot, + a.config.EnableSSHSFTP, + a.config.EnableSSHLocalPortForwarding, + a.config.EnableSSHRemotePortForwarding, + a.config.DisableSSHAuth, + ) +} + +// reconnect closes the current connection and creates a new one +// It checks if the brokenClient is still the current client before reconnecting +// to avoid multiple threads reconnecting unnecessarily +func (a *Auth) reconnect(ctx context.Context, brokenClient *mgm.GrpcClient) error { + a.mutex.Lock() + defer a.mutex.Unlock() + + // Double-check: if client has already been replaced by another thread, skip reconnection + if a.client != brokenClient { + log.Debugf("client already reconnected by another thread, skipping") + return nil + } + + // Create new connection FIRST, before closing the old one + // This ensures a.client is never nil, preventing panics in other threads + log.Debugf("reconnecting to Management Service %s", a.mgmURL.String()) + mgmClient, err := mgm.NewClient(ctx, a.mgmURL.Host, a.privateKey, a.mgmTLSEnabled) + if err != nil { + log.Errorf("failed reconnecting to Management Service %s: %v", a.mgmURL.String(), err) + // Keep the old client if reconnection fails + return err + } + + // Close old connection AFTER new one is successfully created + oldClient := a.client + a.client = mgmClient + + if oldClient != nil { + if err := oldClient.Close(); err != nil { + log.Debugf("error closing old connection: %v", err) + } + } + + log.Debugf("successfully reconnected to Management service %s", a.mgmURL.String()) + return nil +} + +// isConnectionError checks if the error is a connection-related error that should trigger reconnection +func isConnectionError(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + // These error codes indicate connection issues + return s.Code() == codes.Unavailable || + s.Code() == codes.DeadlineExceeded || + s.Code() == codes.Canceled || + s.Code() == codes.Internal +} + +// withRetry wraps an operation with exponential backoff retry logic +// It automatically reconnects on connection errors +func (a *Auth) withRetry(ctx context.Context, operation func(client *mgm.GrpcClient) error) error { + backoffSettings := &backoff.ExponentialBackOff{ + InitialInterval: 500 * time.Millisecond, + RandomizationFactor: 0.5, + Multiplier: 1.5, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 2 * time.Minute, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + } + backoffSettings.Reset() + + return backoff.RetryNotify( + func() error { + // Capture the client BEFORE the operation to ensure we track the correct client + a.mutex.RLock() + currentClient := a.client + a.mutex.RUnlock() + + if currentClient == nil { + return status.Errorf(codes.Unavailable, "client is not initialized") + } + + // Execute operation with the captured client + err := operation(currentClient) + if err == nil { + return nil + } + + // If it's a connection error, attempt reconnection using the client that was actually used + if isConnectionError(err) { + log.Warnf("connection error detected, attempting reconnection: %v", err) + + if reconnectErr := a.reconnect(ctx, currentClient); reconnectErr != nil { + log.Errorf("reconnection failed: %v", reconnectErr) + return reconnectErr + } + // Return the original error to trigger retry with the new connection + return err + } + + // For authentication errors, don't retry + if isAuthenticationError(err) { + return backoff.Permanent(err) + } + + return err + }, + backoff.WithContext(backoffSettings, ctx), + func(err error, duration time.Duration) { + log.Warnf("operation failed, retrying in %v: %v", duration, err) + }, + ) +} + +// isAuthenticationError checks if the error is an authentication-related error that should not be retried. +// Returns true if the error is InvalidArgument or PermissionDenied, indicating that retrying won't help. +func isAuthenticationError(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied +} + +// isPermissionDenied checks if the error is a PermissionDenied error. +// This is used to determine if early exit from backoff is needed (e.g., when the server responded but denied access). +func isPermissionDenied(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + return s.Code() == codes.PermissionDenied +} + +func isLoginNeeded(err error) bool { + return isAuthenticationError(err) +} + +func isRegistrationNeeded(err error) bool { + return isPermissionDenied(err) +} diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index 8ca760742..e33765300 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -15,7 +15,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/util/embeddedroots" ) @@ -26,12 +25,56 @@ const ( var _ OAuthFlow = &DeviceAuthorizationFlow{} +// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow +type DeviceAuthProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Domain An IDP API domain + // Deprecated. Use OIDCConfigEndpoint instead + Domain string + // Audience An Audience for to authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code + DeviceAuthEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string +} + +// validateDeviceAuthConfig validates device authorization provider configuration +func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error { + errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" + + if config.Audience == "" { + return fmt.Errorf(errorMsgFormat, "Audience") + } + if config.ClientID == "" { + return fmt.Errorf(errorMsgFormat, "Client ID") + } + if config.TokenEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Token Endpoint") + } + if config.DeviceAuthEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint") + } + if config.Scope == "" { + return fmt.Errorf(errorMsgFormat, "Device Auth Scopes") + } + return nil +} + // DeviceAuthorizationFlow implements the OAuthFlow interface, // for the Device Authorization Flow. type DeviceAuthorizationFlow struct { - providerConfig internal.DeviceAuthProviderConfig - - HTTPClient HTTPClient + providerConfig DeviceAuthProviderConfig + HTTPClient HTTPClient } // RequestDeviceCodePayload used for request device code payload for auth0 @@ -57,7 +100,7 @@ type TokenRequestResponse struct { } // NewDeviceAuthorizationFlow returns device authorization flow client -func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) { +func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string { return d.providerConfig.ClientID } +// SetLoginHint sets the login hint for the device authorization flow +func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) { + d.providerConfig.LoginHint = hint +} + // RequestAuthInfo requests a device code login flow information from Hosted func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { form := url.Values{} @@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR } // WaitToken waits user's login and authorize the app. Once the user's authorize -// it retrieves the access token from Hosted's endpoint and validates it before returning +// it retrieves the access token from Hosted's endpoint and validates it before returning. +// The method creates a timeout context internally based on info.ExpiresIn. func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) { + // Create timeout context based on flow expiration + timeout := time.Duration(info.ExpiresIn) * time.Second + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + interval := time.Duration(info.Interval) * time.Second ticker := time.NewTicker(interval) + defer ticker.Stop() + for { select { - case <-ctx.Done(): - return TokenInfo{}, ctx.Err() + case <-waitCtx.Done(): + return TokenInfo{}, waitCtx.Err() case <-ticker.C: tokenResponse, err := d.requestToken(info) diff --git a/client/internal/auth/device_flow_test.go b/client/internal/auth/device_flow_test.go index 466645ee9..6a433cb61 100644 --- a/client/internal/auth/device_flow_test.go +++ b/client/internal/auth/device_flow_test.go @@ -12,8 +12,6 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" ) type mockHTTPClient struct { @@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) { err: testCase.inputReqError, } - deviceFlow := &DeviceAuthorizationFlow{ - providerConfig: internal.DeviceAuthProviderConfig{ - Audience: expectedAudience, - ClientID: expectedClientID, - Scope: expectedScope, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - UseIDToken: false, - }, - HTTPClient: &httpClient, + config := DeviceAuthProviderConfig{ + Audience: expectedAudience, + ClientID: expectedClientID, + Scope: expectedScope, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + UseIDToken: false, } + deviceFlow, err := NewDeviceAuthorizationFlow(config) + require.NoError(t, err, "creating device flow should not fail") + deviceFlow.HTTPClient = &httpClient + authInfo, err := deviceFlow.RequestAuthInfo(context.TODO()) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) @@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) { countResBody: testCase.inputCountResBody, } - deviceFlow := DeviceAuthorizationFlow{ - providerConfig: internal.DeviceAuthProviderConfig{ - Audience: testCase.inputAudience, - ClientID: clientID, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - Scope: "openid", - UseIDToken: false, - }, - HTTPClient: &httpClient, + config := DeviceAuthProviderConfig{ + Audience: testCase.inputAudience, + ClientID: clientID, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + Scope: "openid", + UseIDToken: false, } + deviceFlow, err := NewDeviceAuthorizationFlow(config) + require.NoError(t, err, "creating device flow should not fail") + deviceFlow.HTTPClient = &httpClient + ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) defer cancel() tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 85a166005..a50a2ce6f 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -10,7 +10,6 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" ) @@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) + authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + pkceFlowInfo, err := authClient.getPKCEFlow(authClient.client) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } - pkceFlowInfo.ProviderConfig.LoginHint = hint + if hint != "" { + pkceFlowInfo.SetLoginHint(hint) + } - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + return pkceFlowInfo, nil } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { - deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + deviceFlowInfo, err := authClient.getDeviceFlow(authClient.client) if err != nil { switch s, ok := gstatus.FromError(err); { case ok && s.Code() == codes.NotFound: @@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } - deviceFlowInfo.ProviderConfig.LoginHint = hint + if hint != "" { + deviceFlowInfo.SetLoginHint(hint) + } - return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) + return deviceFlowInfo, nil } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index cc43c8648..2e16836d8 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -20,7 +20,6 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/oauth2" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" "github.com/netbirdio/netbird/shared/management/client/common" ) @@ -35,17 +34,67 @@ const ( defaultPKCETimeoutSeconds = 300 ) +// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow +type PKCEAuthProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Audience An Audience for to authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code + AuthorizationEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // RedirectURL handles authorization code from IDP manager + RedirectURLs []string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // ClientCertPair is used for mTLS authentication to the IDP + ClientCertPair *tls.Certificate + // DisablePromptLogin makes the PKCE flow to not prompt the user for login + DisablePromptLogin bool + // LoginFlag is used to configure the PKCE flow login behavior + LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string +} + +// validatePKCEConfig validates PKCE provider configuration +func validatePKCEConfig(config *PKCEAuthProviderConfig) error { + errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" + + if config.ClientID == "" { + return fmt.Errorf(errorMsgFormat, "Client ID") + } + if config.TokenEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Token Endpoint") + } + if config.AuthorizationEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint") + } + if config.Scope == "" { + return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes") + } + if config.RedirectURLs == nil { + return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs") + } + return nil +} + // PKCEAuthorizationFlow implements the OAuthFlow interface for // the Authorization Code Flow with PKCE. type PKCEAuthorizationFlow struct { - providerConfig internal.PKCEAuthProviderConfig + providerConfig PKCEAuthProviderConfig state string codeVerifier string oAuthConfig *oauth2.Config } // NewPKCEAuthorizationFlow returns new PKCE authorization code flow. -func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { +func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string excludedRanges := getSystemExcludedPortRanges() @@ -124,10 +173,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn }, nil } +// SetLoginHint sets the login hint for the PKCE authorization flow +func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) { + p.providerConfig.LoginHint = hint +} + // WaitToken waits for the OAuth token in the PKCE Authorization Flow. // It starts an HTTP server to receive the OAuth token callback and waits for the token or an error. // Once the token is received, it is converted to TokenInfo and validated before returning. -func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) { +// The method creates a timeout context internally based on info.ExpiresIn. +func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) { + // Create timeout context based on flow expiration + timeout := time.Duration(info.ExpiresIn) * time.Second + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) @@ -138,7 +198,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} defer func() { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(shutdownCtx); err != nil { @@ -149,8 +209,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( go p.startServer(server, tokenChan, errChan) select { - case <-ctx.Done(): - return TokenInfo{}, ctx.Err() + case <-waitCtx.Done(): + return TokenInfo{}, waitCtx.Err() case token := <-tokenChan: return p.parseOAuthToken(token) case err := <-errChan: diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b77a17eaa..c487c13df 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/client/internal" mgm "github.com/netbirdio/netbird/shared/management/client/common" ) @@ -50,7 +49,7 @@ func TestPromptLogin(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - config := internal.PKCEAuthProviderConfig{ + config := PKCEAuthProviderConfig{ ClientID: "test-client-id", Audience: "test-audience", TokenEndpoint: "https://test-token-endpoint.com/token", diff --git a/client/internal/auth/pkce_flow_windows_test.go b/client/internal/auth/pkce_flow_windows_test.go index dd455b2fe..125eb270a 100644 --- a/client/internal/auth/pkce_flow_windows_test.go +++ b/client/internal/auth/pkce_flow_windows_test.go @@ -9,8 +9,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" ) func TestParseExcludedPortRanges(t *testing.T) { @@ -95,7 +93,7 @@ func TestNewPKCEAuthorizationFlow_WithActualExcludedPorts(t *testing.T) { availablePort := 65432 - config := internal.PKCEAuthProviderConfig{ + config := PKCEAuthProviderConfig{ ClientID: "test-client-id", Audience: "test-audience", TokenEndpoint: "https://test-token-endpoint.com/token", diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go deleted file mode 100644 index 7f7d06130..000000000 --- a/client/internal/device_auth.go +++ /dev/null @@ -1,136 +0,0 @@ -package internal - -import ( - "context" - "fmt" - "net/url" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - mgm "github.com/netbirdio/netbird/shared/management/client" -) - -// DeviceAuthorizationFlow represents Device Authorization Flow information -type DeviceAuthorizationFlow struct { - Provider string - ProviderConfig DeviceAuthProviderConfig -} - -// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow -type DeviceAuthProviderConfig struct { - // ClientID An IDP application client id - ClientID string - // ClientSecret An IDP application client secret - ClientSecret string - // Domain An IDP API domain - // Deprecated. Use OIDCConfigEndpoint instead - Domain string - // Audience An Audience for to authorization validation - Audience string - // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token - TokenEndpoint string - // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code - DeviceAuthEndpoint string - // Scopes provides the scopes to be included in the token request - Scope string - // UseIDToken indicates if the id token should be used for authentication - UseIDToken bool - // LoginHint is used to pre-fill the email/username field during authentication - LoginHint string -} - -// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it -func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return DeviceAuthorizationFlow{}, err - } - - var mgmTLSEnabled bool - if mgmURL.Scheme == "https" { - mgmTLSEnabled = true - } - - log.Debugf("connecting to Management Service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) - if err != nil { - log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err) - return DeviceAuthorizationFlow{}, err - } - log.Debugf("connected to the Management service %s", mgmURL.String()) - - defer func() { - err = mgmClient.Close() - if err != nil { - log.Warnf("failed to close the Management service client %v", err) - } - }() - - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return DeviceAuthorizationFlow{}, err - } - - protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey) - if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - log.Warnf("server couldn't find device flow, contact admin: %v", err) - return DeviceAuthorizationFlow{}, err - } - log.Errorf("failed to retrieve device flow: %v", err) - return DeviceAuthorizationFlow{}, err - } - - deviceAuthorizationFlow := DeviceAuthorizationFlow{ - Provider: protoDeviceAuthorizationFlow.Provider.String(), - - ProviderConfig: DeviceAuthProviderConfig{ - Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(), - ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(), - ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(), - Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain, - TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), - DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(), - Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(), - UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(), - }, - } - - // keep compatibility with older management versions - if deviceAuthorizationFlow.ProviderConfig.Scope == "" { - deviceAuthorizationFlow.ProviderConfig.Scope = "openid" - } - - err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) - if err != nil { - return DeviceAuthorizationFlow{}, err - } - - return deviceAuthorizationFlow, nil -} - -func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error { - errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" - if config.Audience == "" { - return fmt.Errorf(errorMSGFormat, "Audience") - } - if config.ClientID == "" { - return fmt.Errorf(errorMSGFormat, "Client ID") - } - if config.TokenEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Token Endpoint") - } - if config.DeviceAuthEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint") - } - if config.Scope == "" { - return fmt.Errorf(errorMSGFormat, "Device Auth Scopes") - } - return nil -} diff --git a/client/internal/login.go b/client/internal/login.go deleted file mode 100644 index f528783ef..000000000 --- a/client/internal/login.go +++ /dev/null @@ -1,201 +0,0 @@ -package internal - -import ( - "context" - "net/url" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/client/system" - mgm "github.com/netbirdio/netbird/shared/management/client" - mgmProto "github.com/netbirdio/netbird/shared/management/proto" -) - -// IsLoginRequired check that the server is support SSO or not -func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) { - mgmURL := config.ManagementURL - mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL) - if err != nil { - return false, err - } - defer func() { - err = mgmClient.Close() - if err != nil { - cStatus, ok := status.FromError(err) - if !ok || ok && cStatus.Code() != codes.Canceled { - log.Warnf("failed to close the Management service client, err: %v", err) - } - } - }() - log.Debugf("connected to the Management service %s", mgmURL.String()) - - pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) - if err != nil { - return false, err - } - - _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) - if isLoginNeeded(err) { - return true, nil - } - return false, err -} - -// Login or register the client -func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error { - mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL) - if err != nil { - return err - } - defer func() { - err = mgmClient.Close() - if err != nil { - cStatus, ok := status.FromError(err) - if !ok || ok && cStatus.Code() != codes.Canceled { - log.Warnf("failed to close the Management service client, err: %v", err) - } - } - }() - log.Debugf("connected to the Management service %s", config.ManagementURL.String()) - - pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) - if err != nil { - return err - } - - serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) - if serverKey != nil && isRegistrationNeeded(err) { - log.Debugf("peer registration required") - _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) - if err != nil { - return err - } - } else if err != nil { - return err - } - - return nil -} - -func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return nil, err - } - - var mgmTlsEnabled bool - if mgmURL.Scheme == "https" { - mgmTlsEnabled = true - } - - log.Debugf("connecting to the Management service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled) - if err != nil { - log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err) - return nil, err - } - return mgmClient, err -} - -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, nil, err - } - - sysInfo := system.GetInfo(ctx) - sysInfo.SetFlags( - config.RosenpassEnabled, - config.RosenpassPermissive, - config.ServerSSHAllowed, - config.DisableClientRoutes, - config.DisableServerRoutes, - config.DisableDNS, - config.DisableFirewall, - config.BlockLANAccess, - config.BlockInbound, - config.LazyConnectionEnabled, - config.EnableSSHRoot, - config.EnableSSHSFTP, - config.EnableSSHLocalPortForwarding, - config.EnableSSHRemotePortForwarding, - config.DisableSSHAuth, - ) - loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, loginResp, err -} - -// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. -// Otherwise tries to register with the provided setupKey via command line. -func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { - validSetupKey, err := uuid.Parse(setupKey) - if err != nil && jwtToken == "" { - return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) - } - - log.Debugf("sending peer registration request to Management Service") - info := system.GetInfo(ctx) - info.SetFlags( - config.RosenpassEnabled, - config.RosenpassPermissive, - config.ServerSSHAllowed, - config.DisableClientRoutes, - config.DisableServerRoutes, - config.DisableDNS, - config.DisableFirewall, - config.BlockLANAccess, - config.BlockInbound, - config.LazyConnectionEnabled, - config.EnableSSHRoot, - config.EnableSSHSFTP, - config.EnableSSHLocalPortForwarding, - config.EnableSSHRemotePortForwarding, - config.DisableSSHAuth, - ) - loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) - if err != nil { - log.Errorf("failed registering peer %v", err) - return nil, err - } - - log.Infof("peer has been successfully registered on Management Service") - - return loginResp, nil -} - -func isLoginNeeded(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied { - return true - } - return false -} - -func isRegistrationNeeded(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - if s.Code() == codes.PermissionDenied { - return true - } - return false -} diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go deleted file mode 100644 index 23c92e8af..000000000 --- a/client/internal/pkce_auth.go +++ /dev/null @@ -1,138 +0,0 @@ -package internal - -import ( - "context" - "crypto/tls" - "fmt" - "net/url" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - mgm "github.com/netbirdio/netbird/shared/management/client" - "github.com/netbirdio/netbird/shared/management/client/common" -) - -// PKCEAuthorizationFlow represents PKCE Authorization Flow information -type PKCEAuthorizationFlow struct { - ProviderConfig PKCEAuthProviderConfig -} - -// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow -type PKCEAuthProviderConfig struct { - // ClientID An IDP application client id - ClientID string - // ClientSecret An IDP application client secret - ClientSecret string - // Audience An Audience for to authorization validation - Audience string - // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token - TokenEndpoint string - // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code - AuthorizationEndpoint string - // Scopes provides the scopes to be included in the token request - Scope string - // RedirectURL handles authorization code from IDP manager - RedirectURLs []string - // UseIDToken indicates if the id token should be used for authentication - UseIDToken bool - // ClientCertPair is used for mTLS authentication to the IDP - ClientCertPair *tls.Certificate - // DisablePromptLogin makes the PKCE flow to not prompt the user for login - DisablePromptLogin bool - // LoginFlag is used to configure the PKCE flow login behavior - LoginFlag common.LoginFlag - // LoginHint is used to pre-fill the email/username field during authentication - LoginHint string -} - -// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it -func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return PKCEAuthorizationFlow{}, err - } - - var mgmTLSEnabled bool - if mgmURL.Scheme == "https" { - mgmTLSEnabled = true - } - - log.Debugf("connecting to Management Service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) - if err != nil { - log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err) - return PKCEAuthorizationFlow{}, err - } - log.Debugf("connected to the Management service %s", mgmURL.String()) - - defer func() { - err = mgmClient.Close() - if err != nil { - log.Warnf("failed to close the Management service client %v", err) - } - }() - - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return PKCEAuthorizationFlow{}, err - } - - protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey) - if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - log.Warnf("server couldn't find pkce flow, contact admin: %v", err) - return PKCEAuthorizationFlow{}, err - } - log.Errorf("failed to retrieve pkce flow: %v", err) - return PKCEAuthorizationFlow{}, err - } - - authFlow := PKCEAuthorizationFlow{ - ProviderConfig: PKCEAuthProviderConfig{ - Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(), - ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(), - ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(), - TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), - AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(), - Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(), - RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(), - UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(), - ClientCertPair: clientCert, - DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(), - LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()), - }, - } - - err = isPKCEProviderConfigValid(authFlow.ProviderConfig) - if err != nil { - return PKCEAuthorizationFlow{}, err - } - - return authFlow, nil -} - -func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error { - errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" - if config.ClientID == "" { - return fmt.Errorf(errorMSGFormat, "Client ID") - } - if config.TokenEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Token Endpoint") - } - if config.AuthorizationEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint") - } - if config.Scope == "" { - return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes") - } - if config.RedirectURLs == nil { - return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs") - } - return nil -} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 935910fc9..aafef41d3 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -263,7 +263,14 @@ func (c *Client) IsLoginRequired() bool { return true } - needsLogin, err := internal.IsLoginRequired(ctx, cfg) + authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg) + if err != nil { + log.Errorf("IsLoginRequired: failed to create auth client: %v", err) + return true // Assume login is required if we can't create auth client + } + defer authClient.Close() + + needsLogin, err := authClient.IsLoginRequired(ctx) if err != nil { log.Errorf("IsLoginRequired: check failed: %v", err) // If the check fails, assume login is required to be safe @@ -314,16 +321,19 @@ func (c *Client) LoginForMobile() string { // This could cause a potential race condition with loading the extension which need to be handled on swift side go func() { - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) - defer cancel() - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo) if err != nil { log.Errorf("LoginForMobile: WaitToken failed: %v", err) return } jwtToken := tokenInfo.GetTokenToUse() - if err := internal.Login(ctx, cfg, "", jwtToken); err != nil { + authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg) + if err != nil { + log.Errorf("LoginForMobile: failed to create auth client: %v", err) + return + } + defer authClient.Close() + if err, _ := authClient.Login(ctx, "", jwtToken); err != nil { log.Errorf("LoginForMobile: Login failed: %v", err) return } diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 27fdcf5ef..9d447ef3f 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -7,13 +7,8 @@ import ( "fmt" "time" - "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/client/cmd" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" @@ -90,34 +85,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { } func (a *Auth) saveConfigIfSSOSupported() (bool, error) { - supportsSSO := true - err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - s, ok := gstatus.FromError(err) - if !ok { - return err - } - if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { - supportsSSO = false - err = nil - } + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return false, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() - return err - } - - return err - }) + supportsSSO, err := authClient.IsSSOSupported(a.ctx) + if err != nil { + return false, fmt.Errorf("failed to check SSO support: %v", err) + } if !supportsSSO { return false, nil } - if err != nil { - return false, fmt.Errorf("backoff cycle failed: %v", err) - } - // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename) // which are blocked by the tvOS sandbox in App Group containers err = profilemanager.DirectWriteOutConfig(a.cfgPath, a.config) @@ -141,19 +123,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK } func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) - - err := a.withBackOff(a.ctx, func() error { - backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") - if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { - // we got an answer from management, exit backoff earlier - return backoff.Permanent(backoffErr) - } - return backoffErr - }) + err, _ = authClient.Login(ctxWithValues, setupKey, "") if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } // Use DirectWriteOutConfig to avoid atomic file operations (temp file + rename) @@ -164,15 +144,16 @@ func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string // LoginSync performs a synchronous login check without UI interaction // Used for background VPN connection where user should already be authenticated func (a *Auth) LoginSync() error { - var needsLogin bool + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() // check if we need to generate JWT token - err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) - return - }) + needsLogin, err := authClient.IsLoginRequired(a.ctx) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" @@ -180,15 +161,12 @@ func (a *Auth) LoginSync() error { return fmt.Errorf("not authenticated") } - err = a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, "", jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { - // PermissionDenied means registration is required or peer is blocked - return backoff.Permanent(err) - } - return err - }) + err, isAuthError := authClient.Login(a.ctx, "", jwtToken) if err != nil { + if isAuthError { + // PermissionDenied means registration is required or peer is blocked + return fmt.Errorf("authentication error: %v", err) + } return fmt.Errorf("login failed: %v", err) } @@ -225,8 +203,6 @@ func (a *Auth) LoginWithDeviceName(resultListener ErrListener, urlOpener URLOpen } func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName string) error { - var needsLogin bool - // Create context with device name if provided ctx := a.ctx if deviceName != "" { @@ -234,33 +210,33 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin ctx = context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) } - // check if we need to generate JWT token - err := a.withBackOff(ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(ctx, a.config) - return - }) + authClient, err := auth.NewAuth(ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + // check if we need to generate JWT token + needsLogin, err := authClient.IsLoginRequired(ctx) + if err != nil { + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" if needsLogin { - tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, forceDeviceAuth) + tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, forceDeviceAuth) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } jwtToken = tokenInfo.GetTokenToUse() } - err = a.withBackOff(ctx, func() error { - err := internal.Login(ctx, a.config, "", jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { - // PermissionDenied means registration is required or peer is blocked - return backoff.Permanent(err) - } - return err - }) + err, isAuthError := authClient.Login(ctx, "", jwtToken) if err != nil { + if isAuthError { + // PermissionDenied means registration is required or peer is blocked + return fmt.Errorf("authentication error: %v", err) + } return fmt.Errorf("login failed: %v", err) } @@ -285,10 +261,10 @@ func (a *Auth) login(urlOpener URLOpener, forceDeviceAuth bool, deviceName strin const authInfoRequestTimeout = 30 * time.Second -func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, forceDeviceAuth, "") +func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, forceDeviceAuth bool) (*auth.TokenInfo, error) { + oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, forceDeviceAuth) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get OAuth flow: %v", err) } // Use a bounded timeout for the auth info request to prevent indefinite hangs @@ -313,15 +289,6 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, forceDeviceAuth bool) return &tokenInfo, nil } -func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { - return backoff.RetryNotify( - bf, - backoff.WithContext(cmd.CLIBackOffSettings, ctx), - func(err error, duration time.Duration) { - log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) - }) -} - // GetConfigJSON returns the current config as a JSON string. // This can be used by the caller to persist the config via alternative storage // mechanisms (e.g., UserDefaults on tvOS where file writes are blocked). diff --git a/client/server/server.go b/client/server/server.go index b291d7f71..108eab9fe 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -253,10 +253,17 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil // loginAttempt attempts to login using the provided information. it returns a status in case something fails func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { - var status internal.StatusType - err := internal.Login(ctx, s.config, setupKey, jwtToken) + authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config) if err != nil { - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { + log.Errorf("failed to create auth client: %v", err) + return internal.StatusLoginFailed, err + } + defer authClient.Close() + + var status internal.StatusType + err, isAuthError := authClient.Login(ctx, setupKey, jwtToken) + if err != nil { + if isAuthError { log.Warnf("failed login: %v", err) status = internal.StatusNeedsLogin } else { @@ -581,8 +588,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin s.oauthAuthFlow.waitCancel() } - waitTimeout := time.Until(s.oauthAuthFlow.expiresAt) - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) + waitCTX, cancel := context.WithCancel(ctx) defer cancel() s.mutex.Lock()