diff --git a/client/android/login.go b/client/android/login.go index 4d4c7a650..a7905ea12 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -3,15 +3,7 @@ package android import ( "context" "fmt" - "time" - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/cmd" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" @@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) { } func (a *Auth) saveConfigIfSSOSupported() (bool, error) { - supportsSSO := true - err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - s, ok := gstatus.FromError(err) - if !ok { - return err - } - if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented { - supportsSSO = false - err = nil - } + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return false, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() - return err - } - - return err - }) + supportsSSO, err := authClient.IsSSOSupported(a.ctx) + if err != nil { + return false, fmt.Errorf("failed to check SSO support: %v", err) + } if !supportsSSO { return false, nil } - if err != nil { - return false, fmt.Errorf("backoff cycle failed: %v", err) - } - err = profilemanager.WriteOutConfig(a.cfgPath, a.config) return true, err } @@ -131,17 +110,15 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) - - err := a.withBackOff(a.ctx, func() error { - backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") - if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { - // we got an answer from management, exit backoff earlier - return backoff.Permanent(backoffErr) - } - return backoffErr - }) + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + err = authClient.Login(ctxWithValues, setupKey, "") + if err != nil { + return fmt.Errorf("login failed: %v", err) } return profilemanager.WriteOutConfig(a.cfgPath, a.config) @@ -160,15 +137,16 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT } func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { - var needsLogin bool + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() // check if we need to generate JWT token - err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) - return - }) + needsLogin, err := authClient.IsLoginRequired(a.ctx) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" @@ -180,22 +158,13 @@ func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error { jwtToken = tokenInfo.GetTokenToUse() } - err = a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, "", jwtToken) - - if err == nil { - go urlOpener.OnLoginSuccess() - } - - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - return nil - } - return err - }) + err = authClient.Login(a.ctx, "", jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } + go urlOpener.OnLoginSuccess() + return nil } @@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode) - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout) - defer cancel() - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo) if err != nil { return nil, fmt.Errorf("waiting for browser login failed: %v", err) } return &tokenInfo, nil } - -func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { - return backoff.RetryNotify( - bf, - backoff.WithContext(cmd.CLIBackOffSettings, ctx), - func(err error, duration time.Duration) { - log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) - }) -} diff --git a/client/cmd/login.go b/client/cmd/login.go index 2ddcccc8a..8f38bf293 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -2,13 +2,13 @@ package cmd import ( "context" + "errors" "fmt" "os" "os/exec" "os/user" "runtime" "strings" - "time" log "github.com/sirupsen/logrus" "github.com/skratchdot/open-golang/open" @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" + mgm "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/util" ) @@ -277,18 +278,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo } func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { + authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + needsLogin := false - err := WithBackOff(func() error { - err := internal.Login(ctx, config, "", "") - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - needsLogin = true - return nil - } - return err - }) - if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + err = authClient.Login(ctx, "", "") + if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) { + needsLogin = true + } else if err != nil { + return fmt.Errorf("login check failed: %v", err) } jwtToken := "" @@ -300,23 +302,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken = tokenInfo.GetTokenToUse() } - var lastError error - - err = WithBackOff(func() error { - err := internal.Login(ctx, config, setupKey, jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - lastError = err - return nil - } - return err - }) - - if lastError != nil { - return fmt.Errorf("login failed: %v", lastError) - } - + err = authClient.Login(ctx, setupKey, jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } return nil @@ -344,11 +332,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout) - defer c() - - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) if err != nil { return nil, fmt.Errorf("waiting for browser login failed: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 3090ca6a2..9a762977a 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" sshcommon "github.com/netbirdio/netbird/client/ssh" @@ -168,7 +169,13 @@ func (c *Client) Start(startCtx context.Context) error { ctx := internal.CtxInitState(context.Background()) // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { + authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) + if err != nil { + return fmt.Errorf("create auth client: %w", err) + } + defer authClient.Close() + + if err := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } diff --git a/client/internal/auth/auth.go b/client/internal/auth/auth.go new file mode 100644 index 000000000..9f9e2b571 --- /dev/null +++ b/client/internal/auth/auth.go @@ -0,0 +1,287 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "net/url" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/client/system" + mgm "github.com/netbirdio/netbird/shared/management/client" + "github.com/netbirdio/netbird/shared/management/client/common" +) + +// Auth manages authentication operations with the management server +// The underlying management client handles connection retry and reconnection automatically +type Auth struct { + client *mgm.GrpcClient + config *profilemanager.Config +} + +// NewAuth creates a new Auth instance that manages authentication flows +// It establishes a connection to the management server that will be reused for all operations +// The management client handles connection retry and reconnection automatically +func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) { + // Validate WireGuard private key + myPrivateKey, err := wgtypes.ParseKey(privateKey) + if err != nil { + log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) + return nil, err + } + + // Determine TLS setting based on URL scheme + mgmTLSEnabled := mgmURL.Scheme == "https" + + log.Debugf("connecting to Management Service %s", mgmURL.String()) + mgmClient := mgm.NewClient(mgmURL.Host, myPrivateKey, mgmTLSEnabled) + if err := mgmClient.Connect(ctx); err != nil { + log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err) + return nil, err + } + + log.Debugf("connected to the Management service %s", mgmURL.String()) + + return &Auth{ + client: mgmClient, + config: config, + }, nil +} + +// Close closes the management client connection +func (a *Auth) Close() error { + if a.client == nil { + return nil + } + return a.client.Close() +} + +// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations. +// Returns true if either PKCE or Device authorization flow is supported, false otherwise. +func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) { + // Try PKCE flow first + _, err := a.getPKCEFlow(ctx) + if err == nil { + return true, nil + } + + // Check if PKCE is not supported + if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) { + // PKCE not supported, try Device flow + _, err = a.getDeviceFlow(ctx) + if err == nil { + return true, nil + } + + // Check if Device flow is also not supported + if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) { + // Neither PKCE nor Device flow is supported + return false, nil + } + + // Device flow check returned an error other than NotFound/Unimplemented + return false, err + } + + // PKCE flow check returned an error other than NotFound/Unimplemented + return false, err +} + +// IsLoginRequired checks if login is required by attempting to authenticate with the server +func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) { + pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey)) + if err != nil { + return false, err + } + + err = a.doMgmLogin(ctx, pubSSHKey) + if isLoginNeeded(err) { + return true, nil + } + + return false, err +} + +// Login attempts to log in or register the client with the management server +// Returns custom errors from mgm package: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated +func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) error { + pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey)) + if err != nil { + return fmt.Errorf("generate SSH public key: %w", err) + } + + err = a.doMgmLogin(ctx, pubSSHKey) + if isRegistrationNeeded(err) { + log.Debugf("peer registration required") + return a.registerPeer(ctx, setupKey, jwtToken, pubSSHKey) + } + return err +} + +// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance +func (a *Auth) getPKCEFlow(ctx context.Context) (*PKCEAuthorizationFlow, error) { + protoFlow, err := a.client.GetPKCEAuthorizationFlow(ctx) + if err != nil { + if errors.Is(err, mgm.ErrNotFound) { + log.Warnf("server couldn't find pkce flow, contact admin: %v", err) + return nil, err + } + log.Errorf("failed to retrieve pkce flow: %v", err) + return nil, err + } + + protoConfig := protoFlow.GetProviderConfig() + config := &PKCEAuthProviderConfig{ + Audience: protoConfig.GetAudience(), + ClientID: protoConfig.GetClientID(), + ClientSecret: protoConfig.GetClientSecret(), + TokenEndpoint: protoConfig.GetTokenEndpoint(), + AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(), + Scope: protoConfig.GetScope(), + RedirectURLs: protoConfig.GetRedirectURLs(), + UseIDToken: protoConfig.GetUseIDToken(), + ClientCertPair: a.config.ClientCertKeyPair, + DisablePromptLogin: protoConfig.GetDisablePromptLogin(), + LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()), + } + + if err := validatePKCEConfig(config); err != nil { + return nil, err + } + + flow, err := NewPKCEAuthorizationFlow(*config) + if err != nil { + return nil, err + } + + return flow, nil +} + +// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance +func (a *Auth) getDeviceFlow(ctx context.Context) (*DeviceAuthorizationFlow, error) { + protoFlow, err := a.client.GetDeviceAuthorizationFlow(ctx) + if err != nil { + if errors.Is(err, mgm.ErrNotFound) { + log.Warnf("server couldn't find device flow, contact admin: %v", err) + return nil, err + } + log.Errorf("failed to retrieve device flow: %v", err) + return nil, err + } + + protoConfig := protoFlow.GetProviderConfig() + config := &DeviceAuthProviderConfig{ + Audience: protoConfig.GetAudience(), + ClientID: protoConfig.GetClientID(), + ClientSecret: protoConfig.GetClientSecret(), + Domain: protoConfig.Domain, + TokenEndpoint: protoConfig.GetTokenEndpoint(), + DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(), + Scope: protoConfig.GetScope(), + UseIDToken: protoConfig.GetUseIDToken(), + } + + // Keep compatibility with older management versions + if config.Scope == "" { + config.Scope = "openid" + } + + if err := validateDeviceAuthConfig(config); err != nil { + return nil, err + } + + flow, err := NewDeviceAuthorizationFlow(*config) + if err != nil { + return nil, err + } + + return flow, nil +} + +// doMgmLogin performs the actual login operation with the management service +func (a *Auth) doMgmLogin(ctx context.Context, pubSSHKey []byte) error { + sysInfo := system.GetInfo(ctx) + sysInfo.SetFlags( + a.config.RosenpassEnabled, + a.config.RosenpassPermissive, + a.config.ServerSSHAllowed, + a.config.DisableClientRoutes, + a.config.DisableServerRoutes, + a.config.DisableDNS, + a.config.DisableFirewall, + a.config.BlockLANAccess, + a.config.BlockInbound, + a.config.LazyConnectionEnabled, + a.config.EnableSSHRoot, + a.config.EnableSSHSFTP, + a.config.EnableSSHLocalPortForwarding, + a.config.EnableSSHRemotePortForwarding, + a.config.DisableSSHAuth, + ) + _, err := a.client.Login(ctx, sysInfo, pubSSHKey, a.config.DNSLabels) + return err +} + +// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. +// Otherwise tries to register with the provided setupKey via command line. +func (a *Auth) registerPeer(ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) error { + validSetupKey, err := uuid.Parse(setupKey) + if err != nil && jwtToken == "" { + return fmt.Errorf("%w: invalid setup-key or no SSO information provided: %v", mgm.ErrInvalidArgument, err) + } + + log.Debugf("sending peer registration request to Management Service") + info := system.GetInfo(ctx) + info.SetFlags( + a.config.RosenpassEnabled, + a.config.RosenpassPermissive, + a.config.ServerSSHAllowed, + a.config.DisableClientRoutes, + a.config.DisableServerRoutes, + a.config.DisableDNS, + a.config.DisableFirewall, + a.config.BlockLANAccess, + a.config.BlockInbound, + a.config.LazyConnectionEnabled, + a.config.EnableSSHRoot, + a.config.EnableSSHSFTP, + a.config.EnableSSHLocalPortForwarding, + a.config.EnableSSHRemotePortForwarding, + a.config.DisableSSHAuth, + ) + + // todo: fix error handling of validSetupKey + if err := a.client.Register(ctx, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels); err != nil { + log.Errorf("failed registering peer %v", err) + return err + } + + log.Infof("peer has been successfully registered on Management Service") + + return nil +} + +// isPermissionDenied checks if the error is a PermissionDenied error +func isPermissionDenied(err error) bool { + return errors.Is(err, mgm.ErrPermissionDenied) +} + +// isLoginNeeded checks if the error indicates login is required +func isLoginNeeded(err error) bool { + if err == nil { + return false + } + return errors.Is(err, mgm.ErrInvalidArgument) || + errors.Is(err, mgm.ErrPermissionDenied) || + errors.Is(err, mgm.ErrUnauthenticated) +} + +// isRegistrationNeeded checks if the error indicates peer registration is needed +func isRegistrationNeeded(err error) bool { + return isPermissionDenied(err) +} diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index 8ca760742..e33765300 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -15,7 +15,6 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/util/embeddedroots" ) @@ -26,12 +25,56 @@ const ( var _ OAuthFlow = &DeviceAuthorizationFlow{} +// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow +type DeviceAuthProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Domain An IDP API domain + // Deprecated. Use OIDCConfigEndpoint instead + Domain string + // Audience An Audience for to authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code + DeviceAuthEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string +} + +// validateDeviceAuthConfig validates device authorization provider configuration +func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error { + errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" + + if config.Audience == "" { + return fmt.Errorf(errorMsgFormat, "Audience") + } + if config.ClientID == "" { + return fmt.Errorf(errorMsgFormat, "Client ID") + } + if config.TokenEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Token Endpoint") + } + if config.DeviceAuthEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint") + } + if config.Scope == "" { + return fmt.Errorf(errorMsgFormat, "Device Auth Scopes") + } + return nil +} + // DeviceAuthorizationFlow implements the OAuthFlow interface, // for the Device Authorization Flow. type DeviceAuthorizationFlow struct { - providerConfig internal.DeviceAuthProviderConfig - - HTTPClient HTTPClient + providerConfig DeviceAuthProviderConfig + HTTPClient HTTPClient } // RequestDeviceCodePayload used for request device code payload for auth0 @@ -57,7 +100,7 @@ type TokenRequestResponse struct { } // NewDeviceAuthorizationFlow returns device authorization flow client -func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) { +func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 @@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string { return d.providerConfig.ClientID } +// SetLoginHint sets the login hint for the device authorization flow +func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) { + d.providerConfig.LoginHint = hint +} + // RequestAuthInfo requests a device code login flow information from Hosted func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) { form := url.Values{} @@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR } // WaitToken waits user's login and authorize the app. Once the user's authorize -// it retrieves the access token from Hosted's endpoint and validates it before returning +// it retrieves the access token from Hosted's endpoint and validates it before returning. +// The method creates a timeout context internally based on info.ExpiresIn. func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) { + // Create timeout context based on flow expiration + timeout := time.Duration(info.ExpiresIn) * time.Second + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + interval := time.Duration(info.Interval) * time.Second ticker := time.NewTicker(interval) + defer ticker.Stop() + for { select { - case <-ctx.Done(): - return TokenInfo{}, ctx.Err() + case <-waitCtx.Done(): + return TokenInfo{}, waitCtx.Err() case <-ticker.C: tokenResponse, err := d.requestToken(info) diff --git a/client/internal/auth/device_flow_test.go b/client/internal/auth/device_flow_test.go index 466645ee9..6a433cb61 100644 --- a/client/internal/auth/device_flow_test.go +++ b/client/internal/auth/device_flow_test.go @@ -12,8 +12,6 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal" ) type mockHTTPClient struct { @@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) { err: testCase.inputReqError, } - deviceFlow := &DeviceAuthorizationFlow{ - providerConfig: internal.DeviceAuthProviderConfig{ - Audience: expectedAudience, - ClientID: expectedClientID, - Scope: expectedScope, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - UseIDToken: false, - }, - HTTPClient: &httpClient, + config := DeviceAuthProviderConfig{ + Audience: expectedAudience, + ClientID: expectedClientID, + Scope: expectedScope, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + UseIDToken: false, } + deviceFlow, err := NewDeviceAuthorizationFlow(config) + require.NoError(t, err, "creating device flow should not fail") + deviceFlow.HTTPClient = &httpClient + authInfo, err := deviceFlow.RequestAuthInfo(context.TODO()) testCase.testingErrFunc(t, err, testCase.expectedErrorMSG) @@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) { countResBody: testCase.inputCountResBody, } - deviceFlow := DeviceAuthorizationFlow{ - providerConfig: internal.DeviceAuthProviderConfig{ - Audience: testCase.inputAudience, - ClientID: clientID, - TokenEndpoint: "test.hosted.com/token", - DeviceAuthEndpoint: "test.hosted.com/device/auth", - Scope: "openid", - UseIDToken: false, - }, - HTTPClient: &httpClient, + config := DeviceAuthProviderConfig{ + Audience: testCase.inputAudience, + ClientID: clientID, + TokenEndpoint: "test.hosted.com/token", + DeviceAuthEndpoint: "test.hosted.com/device/auth", + Scope: "openid", + UseIDToken: false, } + deviceFlow, err := NewDeviceAuthorizationFlow(config) + require.NoError(t, err, "creating device flow should not fail") + deviceFlow.HTTPClient = &httpClient + ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout) defer cancel() tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 85a166005..c42d4375b 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -10,7 +10,6 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/profilemanager" ) @@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { - pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) + authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + pkceFlowInfo, err := authClient.getPKCEFlow(ctx) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } - pkceFlowInfo.ProviderConfig.LoginHint = hint + if hint != "" { + pkceFlowInfo.SetLoginHint(hint) + } - return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) + return pkceFlowInfo, nil } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { - deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) + authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + deviceFlowInfo, err := authClient.getDeviceFlow(ctx) if err != nil { switch s, ok := gstatus.FromError(err); { case ok && s.Code() == codes.NotFound: @@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } - deviceFlowInfo.ProviderConfig.LoginHint = hint + if hint != "" { + deviceFlowInfo.SetLoginHint(hint) + } - return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) + return deviceFlowInfo, nil } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 48873f640..9afecb401 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -19,8 +19,8 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/oauth2" - "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/templates" + "github.com/netbirdio/netbird/shared/management/client/common" ) var _ OAuthFlow = &PKCEAuthorizationFlow{} @@ -33,17 +33,67 @@ const ( defaultPKCETimeoutSeconds = 300 ) +// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow +type PKCEAuthProviderConfig struct { + // ClientID An IDP application client id + ClientID string + // ClientSecret An IDP application client secret + ClientSecret string + // Audience An Audience for to authorization validation + Audience string + // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token + TokenEndpoint string + // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code + AuthorizationEndpoint string + // Scopes provides the scopes to be included in the token request + Scope string + // RedirectURL handles authorization code from IDP manager + RedirectURLs []string + // UseIDToken indicates if the id token should be used for authentication + UseIDToken bool + // ClientCertPair is used for mTLS authentication to the IDP + ClientCertPair *tls.Certificate + // DisablePromptLogin makes the PKCE flow to not prompt the user for login + DisablePromptLogin bool + // LoginFlag is used to configure the PKCE flow login behavior + LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string +} + +// validatePKCEConfig validates PKCE provider configuration +func validatePKCEConfig(config *PKCEAuthProviderConfig) error { + errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" + + if config.ClientID == "" { + return fmt.Errorf(errorMsgFormat, "Client ID") + } + if config.TokenEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Token Endpoint") + } + if config.AuthorizationEndpoint == "" { + return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint") + } + if config.Scope == "" { + return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes") + } + if config.RedirectURLs == nil { + return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs") + } + return nil +} + // PKCEAuthorizationFlow implements the OAuthFlow interface for // the Authorization Code Flow with PKCE. type PKCEAuthorizationFlow struct { - providerConfig internal.PKCEAuthProviderConfig + providerConfig PKCEAuthProviderConfig state string codeVerifier string oAuthConfig *oauth2.Config } // NewPKCEAuthorizationFlow returns new PKCE authorization code flow. -func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { +func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) { var availableRedirectURL string // find the first available redirect URL @@ -121,10 +171,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn }, nil } +// SetLoginHint sets the login hint for the PKCE authorization flow +func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) { + p.providerConfig.LoginHint = hint +} + // WaitToken waits for the OAuth token in the PKCE Authorization Flow. // It starts an HTTP server to receive the OAuth token callback and waits for the token or an error. // Once the token is received, it is converted to TokenInfo and validated before returning. -func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) { +// The method creates a timeout context internally based on info.ExpiresIn. +func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) { + // Create timeout context based on flow expiration + timeout := time.Duration(info.ExpiresIn) * time.Second + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + tokenChan := make(chan *oauth2.Token, 1) errChan := make(chan error, 1) @@ -135,7 +196,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())} defer func() { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(shutdownCtx); err != nil { @@ -146,8 +207,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) ( go p.startServer(server, tokenChan, errChan) select { - case <-ctx.Done(): - return TokenInfo{}, ctx.Err() + case <-waitCtx.Done(): + return TokenInfo{}, waitCtx.Err() case token := <-tokenChan: return p.parseOAuthToken(token) case err := <-errChan: diff --git a/client/internal/auth/pkce_flow_test.go b/client/internal/auth/pkce_flow_test.go index b2347d12d..151fd4dad 100644 --- a/client/internal/auth/pkce_flow_test.go +++ b/client/internal/auth/pkce_flow_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/client/internal" mgm "github.com/netbirdio/netbird/shared/management/client/common" ) @@ -41,7 +40,7 @@ func TestPromptLogin(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - config := internal.PKCEAuthProviderConfig{ + config := PKCEAuthProviderConfig{ ClientID: "test-client-id", Audience: "test-audience", TokenEndpoint: "https://test-token-endpoint.com/token", diff --git a/client/internal/connect.go b/client/internal/connect.go index 5a5f4f63c..3ff849160 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -162,6 +162,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return err } + // Create management client once outside retry loop + mgmClient := mgm.NewClient(c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) + mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder) + mgmClient.SetConnStateListener(mgmNotifier) + defer c.statusRecorder.ClientStop() operation := func() error { // if context cancelled we not start new backoff cycle @@ -180,12 +185,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan }() log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host) - mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) - if err != nil { - return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) + if err := mgmClient.ConnectWithoutRetry(engineCtx); err != nil { + return wrapErr(fmt.Errorf("failed connecting to Management Service: %w", err)) } - mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder) - mgmClient.SetConnStateListener(mgmNotifier) log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) defer func() { @@ -198,7 +200,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config) if err != nil { log.Debug(err) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { + if errors.Is(err, mgm.ErrPermissionDenied) { state.Set(StatusNeedsLogin) _ = c.Stop() return backoff.Permanent(wrapErr(err)) // unrecoverable error @@ -320,7 +322,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan err = backoff.Retry(operation, backOff) if err != nil { log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { + if errors.Is(err, mgm.ErrPermissionDenied) { state.Set(StatusNeedsLogin) _ = c.Stop() } @@ -504,12 +506,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { - - serverPublicKey, err := client.GetServerPublicKey() - if err != nil { - return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err) - } - sysInfo := system.GetInfo(ctx) sysInfo.SetFlags( config.RosenpassEnabled, @@ -528,7 +524,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config.EnableSSHRemotePortForwarding, config.DisableSSHAuth, ) - loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) + loginResp, err := client.Login(ctx, sysInfo, pubSSHKey, config.DNSLabels) if err != nil { return nil, err } diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go deleted file mode 100644 index 7f7d06130..000000000 --- a/client/internal/device_auth.go +++ /dev/null @@ -1,136 +0,0 @@ -package internal - -import ( - "context" - "fmt" - "net/url" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - mgm "github.com/netbirdio/netbird/shared/management/client" -) - -// DeviceAuthorizationFlow represents Device Authorization Flow information -type DeviceAuthorizationFlow struct { - Provider string - ProviderConfig DeviceAuthProviderConfig -} - -// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow -type DeviceAuthProviderConfig struct { - // ClientID An IDP application client id - ClientID string - // ClientSecret An IDP application client secret - ClientSecret string - // Domain An IDP API domain - // Deprecated. Use OIDCConfigEndpoint instead - Domain string - // Audience An Audience for to authorization validation - Audience string - // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token - TokenEndpoint string - // DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code - DeviceAuthEndpoint string - // Scopes provides the scopes to be included in the token request - Scope string - // UseIDToken indicates if the id token should be used for authentication - UseIDToken bool - // LoginHint is used to pre-fill the email/username field during authentication - LoginHint string -} - -// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it -func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return DeviceAuthorizationFlow{}, err - } - - var mgmTLSEnabled bool - if mgmURL.Scheme == "https" { - mgmTLSEnabled = true - } - - log.Debugf("connecting to Management Service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) - if err != nil { - log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err) - return DeviceAuthorizationFlow{}, err - } - log.Debugf("connected to the Management service %s", mgmURL.String()) - - defer func() { - err = mgmClient.Close() - if err != nil { - log.Warnf("failed to close the Management service client %v", err) - } - }() - - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return DeviceAuthorizationFlow{}, err - } - - protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey) - if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - log.Warnf("server couldn't find device flow, contact admin: %v", err) - return DeviceAuthorizationFlow{}, err - } - log.Errorf("failed to retrieve device flow: %v", err) - return DeviceAuthorizationFlow{}, err - } - - deviceAuthorizationFlow := DeviceAuthorizationFlow{ - Provider: protoDeviceAuthorizationFlow.Provider.String(), - - ProviderConfig: DeviceAuthProviderConfig{ - Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(), - ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(), - ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(), - Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain, - TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), - DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(), - Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(), - UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(), - }, - } - - // keep compatibility with older management versions - if deviceAuthorizationFlow.ProviderConfig.Scope == "" { - deviceAuthorizationFlow.ProviderConfig.Scope = "openid" - } - - err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig) - if err != nil { - return DeviceAuthorizationFlow{}, err - } - - return deviceAuthorizationFlow, nil -} - -func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error { - errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" - if config.Audience == "" { - return fmt.Errorf(errorMSGFormat, "Audience") - } - if config.ClientID == "" { - return fmt.Errorf(errorMSGFormat, "Client ID") - } - if config.TokenEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Token Endpoint") - } - if config.DeviceAuthEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint") - } - if config.Scope == "" { - return fmt.Errorf(errorMSGFormat, "Device Auth Scopes") - } - return nil -} diff --git a/client/internal/engine.go b/client/internal/engine.go index 0ff1006cd..fb386717d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -891,7 +891,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { e.config.DisableSSHAuth, ) - if err := e.mgmClient.SyncMeta(info); err != nil { + if err := e.mgmClient.SyncMeta(e.ctx, info); err != nil { log.Errorf("could not sync meta: error %s", err) return err } @@ -1517,7 +1517,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err e.config.DisableSSHAuth, ) - netMap, err := e.mgmClient.GetNetworkMap(info) + netMap, err := e.mgmClient.GetNetworkMap(e.ctx, info) if err != nil { return nil, nil, false, err } @@ -1666,7 +1666,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool { signalHealthy := e.signal.IsHealthy() log.Debugf("signal health check: healthy=%t", signalHealthy) - managementHealthy := e.mgmClient.IsHealthy() + managementHealthy := e.mgmClient.IsHealthy(e.ctx) log.Debugf("management health check: healthy=%t", managementHealthy) stuns := slices.Clone(e.STUNs) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9252ce13e..bbc6b286d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1503,8 +1503,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin if err != nil { return nil, err } - mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false) - if err != nil { + mgmtClient := mgmt.NewClient(mgmtAddr, key, false) + if err := mgmtClient.Connect(ctx); err != nil { return nil, err } signalClient, err := signal.NewClient(ctx, signalAddr, key, false) @@ -1512,13 +1512,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin return nil, err } - publicKey, err := mgmtClient.GetServerPublicKey() - if err != nil { - return nil, err - } - info := system.GetInfo(ctx) - resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil) + err = mgmtClient.Register(ctx, setupKey, "", info, nil, nil) if err != nil { return nil, err } @@ -1531,9 +1526,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } wgPort := 33100 + i + testAddr := fmt.Sprintf("100.64.0.%d/24", i+1) conf := &EngineConfig{ WgIfaceName: ifaceName, - WgAddr: resp.PeerConfig.Address, + WgAddr: testAddr, WgPrivateKey: key, WgPort: wgPort, MTU: iface.DefaultMTU, diff --git a/client/internal/login.go b/client/internal/login.go deleted file mode 100644 index f528783ef..000000000 --- a/client/internal/login.go +++ /dev/null @@ -1,201 +0,0 @@ -package internal - -import ( - "context" - "net/url" - - "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/internal/profilemanager" - "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/client/system" - mgm "github.com/netbirdio/netbird/shared/management/client" - mgmProto "github.com/netbirdio/netbird/shared/management/proto" -) - -// IsLoginRequired check that the server is support SSO or not -func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) { - mgmURL := config.ManagementURL - mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL) - if err != nil { - return false, err - } - defer func() { - err = mgmClient.Close() - if err != nil { - cStatus, ok := status.FromError(err) - if !ok || ok && cStatus.Code() != codes.Canceled { - log.Warnf("failed to close the Management service client, err: %v", err) - } - } - }() - log.Debugf("connected to the Management service %s", mgmURL.String()) - - pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) - if err != nil { - return false, err - } - - _, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config) - if isLoginNeeded(err) { - return true, nil - } - return false, err -} - -// Login or register the client -func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error { - mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL) - if err != nil { - return err - } - defer func() { - err = mgmClient.Close() - if err != nil { - cStatus, ok := status.FromError(err) - if !ok || ok && cStatus.Code() != codes.Canceled { - log.Warnf("failed to close the Management service client, err: %v", err) - } - } - }() - log.Debugf("connected to the Management service %s", config.ManagementURL.String()) - - pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey)) - if err != nil { - return err - } - - serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config) - if serverKey != nil && isRegistrationNeeded(err) { - log.Debugf("peer registration required") - _, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config) - if err != nil { - return err - } - } else if err != nil { - return err - } - - return nil -} - -func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return nil, err - } - - var mgmTlsEnabled bool - if mgmURL.Scheme == "https" { - mgmTlsEnabled = true - } - - log.Debugf("connecting to the Management service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled) - if err != nil { - log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err) - return nil, err - } - return mgmClient, err -} - -func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) { - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, nil, err - } - - sysInfo := system.GetInfo(ctx) - sysInfo.SetFlags( - config.RosenpassEnabled, - config.RosenpassPermissive, - config.ServerSSHAllowed, - config.DisableClientRoutes, - config.DisableServerRoutes, - config.DisableDNS, - config.DisableFirewall, - config.BlockLANAccess, - config.BlockInbound, - config.LazyConnectionEnabled, - config.EnableSSHRoot, - config.EnableSSHSFTP, - config.EnableSSHLocalPortForwarding, - config.EnableSSHRemotePortForwarding, - config.DisableSSHAuth, - ) - loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) - return serverKey, loginResp, err -} - -// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. -// Otherwise tries to register with the provided setupKey via command line. -func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { - validSetupKey, err := uuid.Parse(setupKey) - if err != nil && jwtToken == "" { - return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) - } - - log.Debugf("sending peer registration request to Management Service") - info := system.GetInfo(ctx) - info.SetFlags( - config.RosenpassEnabled, - config.RosenpassPermissive, - config.ServerSSHAllowed, - config.DisableClientRoutes, - config.DisableServerRoutes, - config.DisableDNS, - config.DisableFirewall, - config.BlockLANAccess, - config.BlockInbound, - config.LazyConnectionEnabled, - config.EnableSSHRoot, - config.EnableSSHSFTP, - config.EnableSSHLocalPortForwarding, - config.EnableSSHRemotePortForwarding, - config.DisableSSHAuth, - ) - loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) - if err != nil { - log.Errorf("failed registering peer %v", err) - return nil, err - } - - log.Infof("peer has been successfully registered on Management Service") - - return loginResp, nil -} - -func isLoginNeeded(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied { - return true - } - return false -} - -func isRegistrationNeeded(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - if s.Code() == codes.PermissionDenied { - return true - } - return false -} diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go deleted file mode 100644 index 23c92e8af..000000000 --- a/client/internal/pkce_auth.go +++ /dev/null @@ -1,138 +0,0 @@ -package internal - -import ( - "context" - "crypto/tls" - "fmt" - "net/url" - - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - - mgm "github.com/netbirdio/netbird/shared/management/client" - "github.com/netbirdio/netbird/shared/management/client/common" -) - -// PKCEAuthorizationFlow represents PKCE Authorization Flow information -type PKCEAuthorizationFlow struct { - ProviderConfig PKCEAuthProviderConfig -} - -// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow -type PKCEAuthProviderConfig struct { - // ClientID An IDP application client id - ClientID string - // ClientSecret An IDP application client secret - ClientSecret string - // Audience An Audience for to authorization validation - Audience string - // TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token - TokenEndpoint string - // AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code - AuthorizationEndpoint string - // Scopes provides the scopes to be included in the token request - Scope string - // RedirectURL handles authorization code from IDP manager - RedirectURLs []string - // UseIDToken indicates if the id token should be used for authentication - UseIDToken bool - // ClientCertPair is used for mTLS authentication to the IDP - ClientCertPair *tls.Certificate - // DisablePromptLogin makes the PKCE flow to not prompt the user for login - DisablePromptLogin bool - // LoginFlag is used to configure the PKCE flow login behavior - LoginFlag common.LoginFlag - // LoginHint is used to pre-fill the email/username field during authentication - LoginHint string -} - -// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it -func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) { - // validate our peer's Wireguard PRIVATE key - myPrivateKey, err := wgtypes.ParseKey(privateKey) - if err != nil { - log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error()) - return PKCEAuthorizationFlow{}, err - } - - var mgmTLSEnabled bool - if mgmURL.Scheme == "https" { - mgmTLSEnabled = true - } - - log.Debugf("connecting to Management Service %s", mgmURL.String()) - mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled) - if err != nil { - log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err) - return PKCEAuthorizationFlow{}, err - } - log.Debugf("connected to the Management service %s", mgmURL.String()) - - defer func() { - err = mgmClient.Close() - if err != nil { - log.Warnf("failed to close the Management service client %v", err) - } - }() - - serverKey, err := mgmClient.GetServerPublicKey() - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return PKCEAuthorizationFlow{}, err - } - - protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey) - if err != nil { - if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - log.Warnf("server couldn't find pkce flow, contact admin: %v", err) - return PKCEAuthorizationFlow{}, err - } - log.Errorf("failed to retrieve pkce flow: %v", err) - return PKCEAuthorizationFlow{}, err - } - - authFlow := PKCEAuthorizationFlow{ - ProviderConfig: PKCEAuthProviderConfig{ - Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(), - ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(), - ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(), - TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(), - AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(), - Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(), - RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(), - UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(), - ClientCertPair: clientCert, - DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(), - LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()), - }, - } - - err = isPKCEProviderConfigValid(authFlow.ProviderConfig) - if err != nil { - return PKCEAuthorizationFlow{}, err - } - - return authFlow, nil -} - -func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error { - errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator" - if config.ClientID == "" { - return fmt.Errorf(errorMSGFormat, "Client ID") - } - if config.TokenEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Token Endpoint") - } - if config.AuthorizationEndpoint == "" { - return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint") - } - if config.Scope == "" { - return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes") - } - if config.RedirectURLs == nil { - return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs") - } - return nil -} diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 8f467a214..722f20ad6 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -730,8 +730,8 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri return config, err } - client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled) - if err != nil { + client := mgm.NewClient(newURL.Host, key, mgmTlsEnabled) + if err := client.Connect(ctx); err != nil { log.Infof("couldn't switch to the new Management %s", newURL.String()) return config, err } @@ -743,8 +743,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri }() // gRPC check - _, err = client.GetServerPublicKey() - if err != nil { + if err := client.HealthCheck(ctx); err != nil { log.Infof("couldn't switch to the new Management %s", newURL.String()) return nil, err } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 6d969bb12..b3f1a9214 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -208,7 +208,13 @@ func (c *Client) IsLoginRequired() bool { ConfigPath: c.cfgFile, }) - needsLogin, _ := internal.IsLoginRequired(ctx, cfg) + authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg) + if err != nil { + return true // Assume login is required if we can't create auth client + } + defer authClient.Close() + + needsLogin, _ := authClient.IsLoginRequired(ctx) return needsLogin } @@ -240,15 +246,17 @@ func (c *Client) LoginForMobile() string { // This could cause a potential race condition with loading the extension which need to be handled on swift side go func() { - waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) - defer cancel() - tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo) + tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo) if err != nil { return } jwtToken := tokenInfo.GetTokenToUse() - _ = internal.Login(ctx, cfg, "", jwtToken) + authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg) + if err != nil { + return + } + defer authClient.Close() + _ = authClient.Login(ctx, "", jwtToken) c.loginComplete = true }() diff --git a/client/ios/NetBirdSDK/login.go b/client/ios/NetBirdSDK/login.go index 570c44f80..3a7940ccc 100644 --- a/client/ios/NetBirdSDK/login.go +++ b/client/ios/NetBirdSDK/login.go @@ -3,15 +3,8 @@ package NetBirdSDK import ( "context" "fmt" - "time" - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/cmd" - "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" ) @@ -71,30 +64,21 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth // If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO // is not supported and returns false without saving the configuration. For other errors return false. func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { - supportsSSO := true - err := a.withBackOff(a.ctx, func() (err error) { - _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { - supportsSSO = false - err = nil - } + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return false, fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() - return err - } - - return err - }) + supportsSSO, err := authClient.IsSSOSupported(a.ctx) + if err != nil { + return false, fmt.Errorf("failed to check SSO support: %v", err) + } if !supportsSSO { return false, nil } - if err != nil { - return false, fmt.Errorf("backoff cycle failed: %v", err) - } - err = profilemanager.WriteOutConfig(a.cfgPath, a.config) return true, err } @@ -103,32 +87,31 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) { func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error { //nolint ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName) - - err := a.withBackOff(a.ctx, func() error { - backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "") - if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) { - // we got an answer from management, exit backoff earlier - return backoff.Permanent(backoffErr) - } - return backoffErr - }) + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() + + err = authClient.Login(ctxWithValues, setupKey, "") + if err != nil { + return fmt.Errorf("login failed: %v", err) } return profilemanager.WriteOutConfig(a.cfgPath, a.config) } func (a *Auth) Login() error { - var needsLogin bool + authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config) + if err != nil { + return fmt.Errorf("failed to create auth client: %v", err) + } + defer authClient.Close() // check if we need to generate JWT token - err := a.withBackOff(a.ctx, func() (err error) { - needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) - return - }) + needsLogin, err := authClient.IsLoginRequired(a.ctx) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("failed to check login requirement: %v", err) } jwtToken := "" @@ -136,25 +119,10 @@ func (a *Auth) Login() error { return fmt.Errorf("Not authenticated") } - err = a.withBackOff(a.ctx, func() error { - err := internal.Login(a.ctx, a.config, "", jwtToken) - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - return nil - } - return err - }) + err = authClient.Login(a.ctx, "", jwtToken) if err != nil { - return fmt.Errorf("backoff cycle failed: %v", err) + return fmt.Errorf("login failed: %v", err) } return nil } - -func (a *Auth) withBackOff(ctx context.Context, bf func() error) error { - return backoff.RetryNotify( - bf, - backoff.WithContext(cmd.CLIBackOffSettings, ctx), - func(err error, duration time.Duration) { - log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err) - }) -} diff --git a/client/server/server.go b/client/server/server.go index 49000c092..212c7adfb 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -278,14 +278,22 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil // loginAttempt attempts to login using the provided information. it returns a status in case something fails func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { - var status internal.StatusType - err := internal.Login(ctx, s.config, setupKey, jwtToken) + authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config) if err != nil { - if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { - log.Warnf("failed login: %v", err) + log.Errorf("failed to create auth client: %v", err) + return internal.StatusLoginFailed, err + } + defer authClient.Close() + + var status internal.StatusType + err = authClient.Login(ctx, setupKey, jwtToken) + if err != nil { + // Check if it's an authentication error (permission denied, invalid credentials, unauthenticated) + if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) { + log.Warnf("authentication failed: %v", err) status = internal.StatusNeedsLogin } else { - log.Errorf("failed login: %v", err) + log.Errorf("login failed: %v", err) status = internal.StatusLoginFailed } return status, err @@ -606,8 +614,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin s.oauthAuthFlow.waitCancel() } - waitTimeout := time.Until(s.oauthAuthFlow.expiresAt) - waitCTX, cancel := context.WithTimeout(ctx, waitTimeout) + waitCTX, cancel := context.WithCancel(ctx) defer cancel() s.mutex.Lock() @@ -1002,8 +1009,8 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil } mgmTlsEnabled := config.ManagementURL.Scheme == "https" - mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled) - if err != nil { + mgmClient := mgm.NewClient(config.ManagementURL.Host, key, mgmTlsEnabled) + if err := mgmClient.Connect(ctx); err != nil { return fmt.Errorf("connect to management server: %w", err) } defer func() { @@ -1012,7 +1019,7 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil } }() - return mgmClient.Logout() + return mgmClient.Logout(ctx) } // Status returns the daemon status diff --git a/shared/management/client/client.go b/shared/management/client/client.go index 3126bcd1f..5aa6629e7 100644 --- a/shared/management/client/client.go +++ b/shared/management/client/client.go @@ -4,8 +4,6 @@ import ( "context" "io" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" @@ -14,13 +12,13 @@ import ( type Client interface { io.Closer Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error - GetServerPublicKey() (*wgtypes.Key, error) - Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) - GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) - IsHealthy() bool - SyncMeta(sysInfo *system.Info) error - Logout() error + Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) error + Login(ctx context.Context, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) + GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) + GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error) + IsHealthy(ctx context.Context) bool + HealthCheck(ctx context.Context) error + SyncMeta(ctx context.Context, sysInfo *system.Info) error + Logout(ctx context.Context) error } diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 9e08317f6..08993e0e7 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -193,12 +193,12 @@ func TestClient_GetServerPublicKey(t *testing.T) { s, listener := startManagement(t) defer closeManagementSilently(s, listener) - client, err := NewClient(ctx, listener.Addr().String(), testKey, false) - if err != nil { + client := NewClient(listener.Addr().String(), testKey, false) + if err := client.Connect(ctx); err != nil { t.Fatal(err) } - key, err := client.GetServerPublicKey() + key, err := client.getServerPublicKey(ctx) if err != nil { t.Error("couldn't retrieve management public key") } @@ -216,16 +216,12 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) { s, listener := startManagement(t) defer closeManagementSilently(s, listener) - client, err := NewClient(ctx, listener.Addr().String(), testKey, false) - if err != nil { - t.Fatal(err) - } - key, err := client.GetServerPublicKey() - if err != nil { + client := NewClient(listener.Addr().String(), testKey, false) + if err := client.Connect(ctx); err != nil { t.Fatal(err) } sysInfo := system.GetInfo(context.TODO()) - _, err = client.Login(*key, sysInfo, nil, nil) + _, err = client.Login(ctx, sysInfo, nil, nil) if err == nil { t.Error("expecting err on unregistered login, got nil") } @@ -243,24 +239,16 @@ func TestClient_LoginRegistered(t *testing.T) { s, listener := startManagement(t) defer closeManagementSilently(s, listener) - client, err := NewClient(ctx, listener.Addr().String(), testKey, false) - if err != nil { + client := NewClient(listener.Addr().String(), testKey, false) + if err := client.Connect(ctx); err != nil { t.Fatal(err) } - key, err := client.GetServerPublicKey() - if err != nil { - t.Error(err) - } info := system.GetInfo(context.TODO()) - resp, err := client.Register(*key, ValidKey, "", info, nil, nil) + err = client.Register(ctx, ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } - - if resp == nil { - t.Error("expecting non nil response, got nil") - } } func TestClient_Sync(t *testing.T) { @@ -272,18 +260,13 @@ func TestClient_Sync(t *testing.T) { s, listener := startManagement(t) defer closeManagementSilently(s, listener) - client, err := NewClient(ctx, listener.Addr().String(), testKey, false) - if err != nil { + client := NewClient(listener.Addr().String(), testKey, false) + if err := client.Connect(ctx); err != nil { t.Fatal(err) } - serverKey, err := client.GetServerPublicKey() - if err != nil { - t.Error(err) - } - info := system.GetInfo(context.TODO()) - _, err = client.Register(*serverKey, ValidKey, "", info, nil, nil) + err = client.Register(ctx, ValidKey, "", info, nil, nil) if err != nil { t.Error(err) } @@ -293,13 +276,14 @@ func TestClient_Sync(t *testing.T) { if err != nil { t.Error(err) } - remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false) - if err != nil { + remoteClient := NewClient(listener.Addr().String(), remoteKey, false) + remoteCtx := context.TODO() + if err := remoteClient.Connect(remoteCtx); err != nil { t.Fatal(err) } info = system.GetInfo(context.TODO()) - _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil) + err = remoteClient.Register(remoteCtx, ValidKey, "", info, nil, nil) if err != nil { t.Fatal(err) } @@ -354,14 +338,9 @@ func Test_SystemMetaDataFromClient(t *testing.T) { serverAddr := lis.Addr().String() ctx := context.Background() - testClient, err := NewClient(ctx, serverAddr, testKey, false) - if err != nil { - t.Fatalf("error while creating testClient: %v", err) - } - - key, err := testClient.GetServerPublicKey() - if err != nil { - t.Fatalf("error while getting server public key from testclient, %v", err) + testClient := NewClient(serverAddr, testKey, false) + if err := testClient.Connect(ctx); err != nil { + t.Fatalf("error while connecting testClient: %v", err) } var actualMeta *mgmtProto.PeerSystemMeta @@ -400,7 +379,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) { } info := system.GetInfo(context.TODO()) - _, err = testClient.Register(*key, ValidKey, "", info, nil, nil) + err = testClient.Register(ctx, ValidKey, "", info, nil, nil) if err != nil { t.Errorf("error while trying to register client: %v", err) } @@ -489,9 +468,9 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { serverAddr := lis.Addr().String() ctx := context.Background() - client, err := NewClient(ctx, serverAddr, testKey, false) - if err != nil { - t.Fatalf("error while creating testClient: %v", err) + client := NewClient(serverAddr, testKey, false) + if err := client.Connect(ctx); err != nil { + t.Fatalf("error while connecting testClient: %v", err) } expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{ @@ -512,7 +491,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) { }, nil } - flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey) + flowInfo, err := client.GetDeviceAuthorizationFlow(ctx) if err != nil { t.Error("error while retrieving device auth flow information") } @@ -533,9 +512,9 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { serverAddr := lis.Addr().String() ctx := context.Background() - client, err := NewClient(ctx, serverAddr, testKey, false) - if err != nil { - t.Fatalf("error while creating testClient: %v", err) + client := NewClient(serverAddr, testKey, false) + if err := client.Connect(ctx); err != nil { + t.Fatalf("error while connecting testClient: %v", err) } expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{ @@ -558,7 +537,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) { }, nil } - flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey) + flowInfo, err := client.GetPKCEAuthorizationFlow(ctx) if err != nil { t.Error("error while retrieving pkce auth flow information") } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 520a83e36..04558e333 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -25,6 +25,24 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +// Custom management client errors that abstract away gRPC error codes +var ( + // ErrPermissionDenied is returned when the server denies access to a resource + ErrPermissionDenied = errors.New("permission denied") + + // ErrInvalidArgument is returned when the request contains invalid arguments + ErrInvalidArgument = errors.New("invalid argument") + + // ErrUnauthenticated is returned when authentication is required + ErrUnauthenticated = errors.New("unauthenticated") + + // ErrNotFound is returned when the requested resource is not found + ErrNotFound = errors.New("not found") + + // ErrUnimplemented is returned when the operation is not implemented + ErrUnimplemented = errors.New("not implemented") +) + const ConnectTimeout = 10 * time.Second const ( @@ -41,40 +59,67 @@ type ConnStateNotifier interface { type GrpcClient struct { key wgtypes.Key realClient proto.ManagementServiceClient - ctx context.Context conn *grpc.ClientConn connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex + addr string + tlsEnabled bool + reconnectMutex sync.Mutex } // NewClient creates a new client to Management service -func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { +// The client is not connected after creation - call Connect to establish the connection +func NewClient(addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) *GrpcClient { + return &GrpcClient{ + key: ourPrivateKey, + addr: addr, + tlsEnabled: tlsEnabled, + connStateCallbackLock: sync.RWMutex{}, + reconnectMutex: sync.Mutex{}, + } +} + +// Connect establishes a connection to the Management Service with retry logic +// Retries connection attempts with exponential backoff on failure +func (c *GrpcClient) Connect(ctx context.Context) error { var conn *grpc.ClientConn operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) + conn, err = nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent) if err != nil { - return fmt.Errorf("create connection: %w", err) + log.Warnf("failed to connect to Management Service: %v", err) + return err } return nil } - err := backoff.Retry(operation, nbgrpc.Backoff(ctx)) - if err != nil { - log.Errorf("failed creating connection to Management Service: %v", err) - return nil, err + if err := backoff.Retry(operation, defaultBackoff(ctx)); err != nil { + log.Errorf("failed creating connection to Management Service after retries: %v", err) + return fmt.Errorf("create connection: %w", err) } - realClient := proto.NewManagementServiceClient(conn) + c.conn = conn + c.realClient = proto.NewManagementServiceClient(conn) - return &GrpcClient{ - key: ourPrivateKey, - realClient: realClient, - ctx: ctx, - conn: conn, - connStateCallbackLock: sync.RWMutex{}, - }, nil + log.Infof("connected to the Management Service at %s", c.addr) + return nil +} + +// ConnectWithoutRetry establishes a connection to the Management Service without retry logic +// Performs a single connection attempt - callers should implement their own retry logic if needed +func (c *GrpcClient) ConnectWithoutRetry(ctx context.Context) error { + conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent) + if err != nil { + log.Warnf("failed to connect to Management Service: %v", err) + return fmt.Errorf("create connection: %w", err) + } + + c.conn = conn + c.realClient = proto.NewManagementServiceClient(conn) + + log.Debugf("connected to the Management Service at %s", c.addr) + return nil } // Close closes connection to the Management Service @@ -89,19 +134,6 @@ func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) { c.connStateCallback = notifier } -// defaultBackoff is a basic backoff mechanism for general issues -func defaultBackoff(ctx context.Context) backoff.BackOff { - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: 1, - Multiplier: 1.7, - MaxInterval: 10 * time.Second, - MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - // ready indicates whether the client is okay and ready to be used // for now it just checks whether gRPC connection to the service is ready func (c *GrpcClient) ready() bool { @@ -122,7 +154,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return fmt.Errorf("connection to management is not ready and in %s state", connState) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey(ctx) if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err @@ -177,15 +209,13 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, } // GetNetworkMap return with the network map -func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) { - serverPubKey, err := c.GetServerPublicKey() +func (c *GrpcClient) GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error) { + serverPubKey, err := c.getServerPublicKey(ctx) if err != nil { log.Debugf("failed getting Management Service public key: %s", err) return nil, err } - ctx, cancelStream := context.WithCancel(c.ctx) - defer cancelStream() stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo) if err != nil { log.Debugf("failed to open Management Service stream: %s", err) @@ -264,30 +294,32 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se } } -// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) -func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { +// getServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) +// This is a simple operation without retry logic - callers should handle retries at the operation level +func (c *GrpcClient) getServerPublicKey(ctx context.Context) (*wgtypes.Key, error) { if !c.ready() { return nil, errors.New(errMsgNoMgmtConnection) } - mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) + mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() + resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) if err != nil { log.Errorf("failed while getting Management Service public key: %v", err) return nil, fmt.Errorf("failed while getting Management Service public key") } - serverKey, err := wgtypes.ParseKey(resp.Key) + key, err := wgtypes.ParseKey(resp.Key) if err != nil { return nil, err } - return &serverKey, nil + return &key, nil } // IsHealthy probes the gRPC connection and returns false on errors -func (c *GrpcClient) IsHealthy() bool { +func (c *GrpcClient) IsHealthy(ctx context.Context) bool { switch c.conn.GetState() { case connectivity.TransientFailure: return false @@ -299,10 +331,10 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + healthCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() - _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) + _, err := c.realClient.GetServerKey(healthCtx, &proto.Empty{}) if err != nil { c.notifyDisconnected(err) log.Warnf("health check returned: %s", err) @@ -312,12 +344,26 @@ func (c *GrpcClient) IsHealthy() bool { return true } -func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { +// HealthCheck verifies connectivity to the management server +// Returns an error if the server is not reachable +// Internally uses getServerPublicKey to verify the connection +func (c *GrpcClient) HealthCheck(ctx context.Context) error { + _, err := c.getServerPublicKey(ctx) + return err +} + +func (c *GrpcClient) login(ctx context.Context, req *proto.LoginRequest) (*proto.LoginResponse, error) { if !c.ready() { return nil, errors.New(errMsgNoMgmtConnection) } - loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) + serverKey, err := c.getServerPublicKey(ctx) + if err != nil { + log.Debugf(errMsgMgmtPublicKey, err) + return nil, err + } + + loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) return nil, err @@ -325,7 +371,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro var resp *proto.EncryptedMessage operation := func() error { - mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout) + mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) defer cancel() var err error @@ -344,14 +390,14 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro return nil } - err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx)) + err = backoff.Retry(operation, nbgrpc.Backoff(ctx)) if err != nil { log.Errorf("failed to login to Management Service: %v", err) return nil, err } loginResp := &proto.LoginResponse{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp) if err != nil { log.Errorf("failed to decrypt login response: %s", err) return nil, err @@ -363,99 +409,135 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Takes care of encrypting and decrypting messages. // This method will also collect system info and send it with the request (e.g. hostname, os, etc) -func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated +func (c *GrpcClient) Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) error { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + _, err := c.login(ctx, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + return wrapGRPCError(err) } // Login attempts login to Management Server. Takes care of encrypting and decrypting messages. -func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated +func (c *GrpcClient) Login(ctx context.Context, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { keys := &proto.PeerKeys{ SshPubKey: pubSSHKey, WgPubKey: []byte(c.key.PublicKey().String()), } - return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + resp, err := c.login(ctx, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) + return resp, wrapGRPCError(err) } // GetDeviceAuthorizationFlow returns a device authorization flow information. // It also takes care of encrypting and decrypting messages. -func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { - if !c.ready() { - return nil, fmt.Errorf("no connection to management in order to get device authorization flow") - } - mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) - defer cancel() +// It automatically retries with backoff and reconnection on connection errors. +// Returns custom errors: ErrNotFound, ErrUnimplemented +func (c *GrpcClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) { + var flowInfoResp *proto.DeviceAuthorizationFlow - message := &proto.DeviceAuthorizationFlowRequest{} - encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) - if err != nil { - return nil, err - } + err := c.withRetry(ctx, func() error { + if !c.ready() { + return fmt.Errorf("no connection to management in order to get device authorization flow") + } - resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{ - WgPubKey: c.key.PublicKey().String(), - Body: encryptedMSG}, - ) - if err != nil { - return nil, err - } + serverKey, err := c.getServerPublicKey(ctx) + if err != nil { + log.Debugf(errMsgMgmtPublicKey, err) + return err + } - flowInfoResp := &proto.DeviceAuthorizationFlow{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) - if err != nil { - errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err) - log.Error(errWithMSG) - return nil, errWithMSG - } + mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2) + defer cancel() - return flowInfoResp, nil + message := &proto.DeviceAuthorizationFlowRequest{} + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) + if err != nil { + return err + } + + resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: encryptedMSG}, + ) + if err != nil { + return err + } + + flowInfo := &proto.DeviceAuthorizationFlow{} + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo) + if err != nil { + errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err) + log.Error(errWithMSG) + return errWithMSG + } + + flowInfoResp = flowInfo + return nil + }) + + return flowInfoResp, wrapGRPCError(err) } // GetPKCEAuthorizationFlow returns a pkce authorization flow information. // It also takes care of encrypting and decrypting messages. -func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { - if !c.ready() { - return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow") - } - mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) - defer cancel() +// It automatically retries with backoff and reconnection on connection errors. +// Returns custom errors: ErrNotFound, ErrUnimplemented +func (c *GrpcClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) { + var flowInfoResp *proto.PKCEAuthorizationFlow - message := &proto.PKCEAuthorizationFlowRequest{} - encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) - if err != nil { - return nil, err - } + err := c.withRetry(ctx, func() error { + if !c.ready() { + return fmt.Errorf("no connection to management in order to get pkce authorization flow") + } - resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{ - WgPubKey: c.key.PublicKey().String(), - Body: encryptedMSG, + serverKey, err := c.getServerPublicKey(ctx) + if err != nil { + log.Debugf(errMsgMgmtPublicKey, err) + return err + } + + mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2) + defer cancel() + + message := &proto.PKCEAuthorizationFlowRequest{} + encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message) + if err != nil { + return err + } + + resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: encryptedMSG, + }) + if err != nil { + return err + } + + flowInfo := &proto.PKCEAuthorizationFlow{} + err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo) + if err != nil { + errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err) + log.Error(errWithMSG) + return errWithMSG + } + + flowInfoResp = flowInfo + return nil }) - if err != nil { - return nil, err - } - flowInfoResp := &proto.PKCEAuthorizationFlow{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) - if err != nil { - errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err) - log.Error(errWithMSG) - return nil, errWithMSG - } - - return flowInfoResp, nil + return flowInfoResp, wrapGRPCError(err) } // SyncMeta sends updated system metadata to the Management Service. // It should be used if there is changes on peer posture check after initial sync. -func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error { +func (c *GrpcClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error { if !c.ready() { return errors.New(errMsgNoMgmtConnection) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.getServerPublicKey(ctx) if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err @@ -467,7 +549,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error { return err } - mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout) + mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout) defer cancel() _, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{ @@ -497,13 +579,13 @@ func (c *GrpcClient) notifyConnected() { c.connStateCallback.MarkManagementConnected() } -func (c *GrpcClient) Logout() error { - serverKey, err := c.GetServerPublicKey() +func (c *GrpcClient) Logout(ctx context.Context) error { + serverKey, err := c.getServerPublicKey(ctx) if err != nil { return fmt.Errorf("get server public key: %w", err) } - mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*15) + mgmCtx, cancel := context.WithTimeout(ctx, time.Second*15) defer cancel() message := &proto.Empty{} @@ -523,6 +605,156 @@ func (c *GrpcClient) Logout() error { return nil } +// reconnect closes the current connection and creates a new one +func (c *GrpcClient) reconnect(ctx context.Context) error { + c.reconnectMutex.Lock() + defer c.reconnectMutex.Unlock() + + // Close existing connection + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Debugf("error closing old connection: %v", err) + } + } + + // Create new connection + log.Debugf("reconnecting to Management Service %s", c.addr) + conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent) + if err != nil { + log.Errorf("failed reconnecting to Management Service %s: %v", c.addr, err) + return fmt.Errorf("reconnect: create connection: %w", err) + } + + c.conn = conn + c.realClient = proto.NewManagementServiceClient(conn) + log.Debugf("successfully reconnected to Management service %s", c.addr) + return nil +} + +// withRetry wraps an operation with exponential backoff retry logic +// It automatically reconnects on connection errors +func (c *GrpcClient) withRetry(ctx context.Context, operation func() error) error { + backoffSettings := &backoff.ExponentialBackOff{ + InitialInterval: 500 * time.Millisecond, + RandomizationFactor: 0.5, + Multiplier: 1.5, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 2 * time.Minute, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + } + backoffSettings.Reset() + + return backoff.RetryNotify( + func() error { + err := operation() + if err == nil { + return nil + } + + // If it's a connection error, attempt reconnection + if isConnectionError(err) { + log.Warnf("connection error detected, attempting reconnection: %v", err) + if reconnectErr := c.reconnect(ctx); reconnectErr != nil { + log.Errorf("reconnection failed: %v", reconnectErr) + return reconnectErr + } + // Return the original error to trigger retry with the new connection + return err + } + + // For authentication errors (InvalidArgument, PermissionDenied), don't retry + if isAuthenticationError(err) { + return backoff.Permanent(err) + } + + return err + }, + backoff.WithContext(backoffSettings, ctx), + func(err error, duration time.Duration) { + log.Warnf("operation failed, retrying in %v: %v", duration, err) + }, + ) +} + +// defaultBackoff is a basic backoff mechanism for general issues +func defaultBackoff(ctx context.Context) backoff.BackOff { + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: 1, + Multiplier: 1.7, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// isConnectionError checks if the error is a connection-related error that should trigger reconnection +func isConnectionError(err error) bool { + if err == nil { + return false + } + s, ok := gstatus.FromError(err) + if !ok { + return false + } + // These error codes indicate connection issues + return s.Code() == codes.Unavailable || + s.Code() == codes.DeadlineExceeded || + s.Code() == codes.Canceled || + s.Code() == codes.Internal +} + +// isAuthenticationError checks if the error is an authentication-related error that should not be retried +func isAuthenticationError(err error) bool { + if err == nil { + return false + } + s, ok := gstatus.FromError(err) + if !ok { + return false + } + return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied +} + +// wrapGRPCError converts gRPC errors to custom management client errors +func wrapGRPCError(err error) error { + if err == nil { + return nil + } + + // Check if it's already a custom error + if errors.Is(err, ErrPermissionDenied) || + errors.Is(err, ErrInvalidArgument) || + errors.Is(err, ErrUnauthenticated) || + errors.Is(err, ErrNotFound) || + errors.Is(err, ErrUnimplemented) { + return err + } + + // Convert gRPC status errors to custom errors + s, ok := gstatus.FromError(err) + if !ok { + return err + } + + switch s.Code() { + case codes.PermissionDenied: + return fmt.Errorf("%w: %s", ErrPermissionDenied, s.Message()) + case codes.InvalidArgument: + return fmt.Errorf("%w: %s", ErrInvalidArgument, s.Message()) + case codes.Unauthenticated: + return fmt.Errorf("%w: %s", ErrUnauthenticated, s.Message()) + case codes.NotFound: + return fmt.Errorf("%w: %s", ErrNotFound, s.Message()) + case codes.Unimplemented: + return fmt.Errorf("%w: %s", ErrUnimplemented, s.Message()) + default: + return err + } +} + func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { if info == nil { return nil diff --git a/shared/management/client/mock.go b/shared/management/client/mock.go index 29006c9c3..e81bc4865 100644 --- a/shared/management/client/mock.go +++ b/shared/management/client/mock.go @@ -3,8 +3,6 @@ package client import ( "context" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" @@ -13,17 +11,21 @@ import ( type MockClient struct { CloseFunc func() error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error - GetServerPublicKeyFunc func() (*wgtypes.Key, error) - RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) - GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) - GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) - SyncMetaFunc func(sysInfo *system.Info) error - LogoutFunc func() error + RegisterFunc func(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error + LoginFunc func(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) + GetDeviceAuthorizationFlowFunc func(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) + GetPKCEAuthorizationFlowFunc func(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) + SyncMetaFunc func(ctx context.Context, sysInfo *system.Info) error + HealthCheckFunc func(ctx context.Context) error + LogoutFunc func(ctx context.Context) error + IsHealthyFunc func(ctx context.Context) bool } -func (m *MockClient) IsHealthy() bool { - return true +func (m *MockClient) IsHealthy(ctx context.Context) bool { + if m.IsHealthyFunc == nil { + return true + } + return m.IsHealthyFunc(ctx) } func (m *MockClient) Close() error { @@ -40,56 +42,56 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return m.SyncFunc(ctx, sysInfo, msgHandler) } -func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { - if m.GetServerPublicKeyFunc == nil { - return nil, nil - } - return m.GetServerPublicKeyFunc() -} - -func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (m *MockClient) Register(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error { if m.RegisterFunc == nil { - return nil, nil + return nil } - return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels) + return m.RegisterFunc(ctx, setupKey, jwtToken, info, sshKey, dnsLabels) } -func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { +func (m *MockClient) Login(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { if m.LoginFunc == nil { return nil, nil } - return m.LoginFunc(serverKey, info, sshKey, dnsLabels) + return m.LoginFunc(ctx, info, sshKey, dnsLabels) } -func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { +func (m *MockClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) { if m.GetDeviceAuthorizationFlowFunc == nil { return nil, nil } - return m.GetDeviceAuthorizationFlowFunc(serverKey) + return m.GetDeviceAuthorizationFlowFunc(ctx) } -func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { +func (m *MockClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) { if m.GetPKCEAuthorizationFlowFunc == nil { return nil, nil } - return m.GetPKCEAuthorizationFlow(serverKey) + return m.GetPKCEAuthorizationFlowFunc(ctx) } // GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface -func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) { +func (m *MockClient) GetNetworkMap(ctx context.Context, _ *system.Info) (*proto.NetworkMap, error) { return nil, nil } -func (m *MockClient) SyncMeta(sysInfo *system.Info) error { +func (m *MockClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error { if m.SyncMetaFunc == nil { return nil } - return m.SyncMetaFunc(sysInfo) + return m.SyncMetaFunc(ctx, sysInfo) } -func (m *MockClient) Logout() error { +func (m *MockClient) HealthCheck(ctx context.Context) error { + if m.HealthCheckFunc == nil { + return nil + } + return m.HealthCheckFunc(ctx) +} + +func (m *MockClient) Logout(ctx context.Context) error { if m.LogoutFunc == nil { return nil } - return m.LogoutFunc() + return m.LogoutFunc(ctx) }