mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Compare commits
1 Commits
feature/up
...
retries-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b3e1f1b52 |
@@ -3,15 +3,7 @@ package android
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
defer authClient.Close()
|
||||||
s, ok := gstatus.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
@@ -131,17 +110,15 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
|
|||||||
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
err = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
@@ -160,15 +137,16 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -180,22 +158,13 @@ func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
go urlOpener.OnLoginSuccess()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a
|
|||||||
|
|
||||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,13 +2,13 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -277,18 +278,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||||
|
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin := false
|
||||||
|
|
||||||
err := WithBackOff(func() error {
|
err = authClient.Login(ctx, "", "")
|
||||||
err := internal.Login(ctx, config, "", "")
|
if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
needsLogin = true
|
||||||
needsLogin = true
|
} else if err != nil {
|
||||||
return nil
|
return fmt.Errorf("login check failed: %v", err)
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -300,23 +302,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
jwtToken = tokenInfo.GetTokenToUse()
|
jwtToken = tokenInfo.GetTokenToUse()
|
||||||
}
|
}
|
||||||
|
|
||||||
var lastError error
|
err = authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
err := internal.Login(ctx, config, setupKey, jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
lastError = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
if lastError != nil {
|
|
||||||
return fmt.Errorf("login failed: %v", lastError)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -344,11 +332,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
|
||||||
defer c()
|
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -168,7 +169,13 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
ctx := internal.CtxInitState(context.Background())
|
ctx := internal.CtxInitState(context.Background())
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
if err := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
287
client/internal/auth/auth.go
Normal file
287
client/internal/auth/auth.go
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Auth manages authentication operations with the management server
|
||||||
|
// The underlying management client handles connection retry and reconnection automatically
|
||||||
|
type Auth struct {
|
||||||
|
client *mgm.GrpcClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuth creates a new Auth instance that manages authentication flows
|
||||||
|
// It establishes a connection to the management server that will be reused for all operations
|
||||||
|
// The management client handles connection retry and reconnection automatically
|
||||||
|
func NewAuth(ctx context.Context, privateKey string, mgmURL *url.URL, config *profilemanager.Config) (*Auth, error) {
|
||||||
|
// Validate WireGuard private key
|
||||||
|
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine TLS setting based on URL scheme
|
||||||
|
mgmTLSEnabled := mgmURL.Scheme == "https"
|
||||||
|
|
||||||
|
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
||||||
|
mgmClient := mgm.NewClient(mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
||||||
|
if err := mgmClient.Connect(ctx); err != nil {
|
||||||
|
log.Errorf("failed connecting to Management Service %s: %v", mgmURL.String(), err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("connected to the Management service %s", mgmURL.String())
|
||||||
|
|
||||||
|
return &Auth{
|
||||||
|
client: mgmClient,
|
||||||
|
config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the management client connection
|
||||||
|
func (a *Auth) Close() error {
|
||||||
|
if a.client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSSOSupported checks if the management server supports SSO by attempting to retrieve auth flow configurations.
|
||||||
|
// Returns true if either PKCE or Device authorization flow is supported, false otherwise.
|
||||||
|
func (a *Auth) IsSSOSupported(ctx context.Context) (bool, error) {
|
||||||
|
// Try PKCE flow first
|
||||||
|
_, err := a.getPKCEFlow(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if PKCE is not supported
|
||||||
|
if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) {
|
||||||
|
// PKCE not supported, try Device flow
|
||||||
|
_, err = a.getDeviceFlow(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Device flow is also not supported
|
||||||
|
if errors.Is(err, mgm.ErrNotFound) || errors.Is(err, mgm.ErrUnimplemented) {
|
||||||
|
// Neither PKCE nor Device flow is supported
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Device flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCE flow check returned an error other than NotFound/Unimplemented
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLoginRequired checks if login is required by attempting to authenticate with the server
|
||||||
|
func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.doMgmLogin(ctx, pubSSHKey)
|
||||||
|
if isLoginNeeded(err) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login attempts to log in or register the client with the management server
|
||||||
|
// Returns custom errors from mgm package: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
|
||||||
|
func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) error {
|
||||||
|
pubSSHKey, err := ssh.GeneratePublicKey([]byte(a.config.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate SSH public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.doMgmLogin(ctx, pubSSHKey)
|
||||||
|
if isRegistrationNeeded(err) {
|
||||||
|
log.Debugf("peer registration required")
|
||||||
|
return a.registerPeer(ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getPKCEFlow(ctx context.Context) (*PKCEAuthorizationFlow, error) {
|
||||||
|
protoFlow, err := a.client.GetPKCEAuthorizationFlow(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, mgm.ErrNotFound) {
|
||||||
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve pkce flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &PKCEAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
RedirectURLs: protoConfig.GetRedirectURLs(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
ClientCertPair: a.config.ClientCertKeyPair,
|
||||||
|
DisablePromptLogin: protoConfig.GetDisablePromptLogin(),
|
||||||
|
LoginFlag: common.LoginFlag(protoConfig.GetLoginFlag()),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validatePKCEConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewPKCEAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
|
func (a *Auth) getDeviceFlow(ctx context.Context) (*DeviceAuthorizationFlow, error) {
|
||||||
|
protoFlow, err := a.client.GetDeviceAuthorizationFlow(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, mgm.ErrNotFound) {
|
||||||
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Errorf("failed to retrieve device flow: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
protoConfig := protoFlow.GetProviderConfig()
|
||||||
|
config := &DeviceAuthProviderConfig{
|
||||||
|
Audience: protoConfig.GetAudience(),
|
||||||
|
ClientID: protoConfig.GetClientID(),
|
||||||
|
ClientSecret: protoConfig.GetClientSecret(),
|
||||||
|
Domain: protoConfig.Domain,
|
||||||
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
|
Scope: protoConfig.GetScope(),
|
||||||
|
UseIDToken: protoConfig.GetUseIDToken(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep compatibility with older management versions
|
||||||
|
if config.Scope == "" {
|
||||||
|
config.Scope = "openid"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateDeviceAuthConfig(config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
flow, err := NewDeviceAuthorizationFlow(*config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return flow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
|
func (a *Auth) doMgmLogin(ctx context.Context, pubSSHKey []byte) error {
|
||||||
|
sysInfo := system.GetInfo(ctx)
|
||||||
|
sysInfo.SetFlags(
|
||||||
|
a.config.RosenpassEnabled,
|
||||||
|
a.config.RosenpassPermissive,
|
||||||
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.DisableClientRoutes,
|
||||||
|
a.config.DisableServerRoutes,
|
||||||
|
a.config.DisableDNS,
|
||||||
|
a.config.DisableFirewall,
|
||||||
|
a.config.BlockLANAccess,
|
||||||
|
a.config.BlockInbound,
|
||||||
|
a.config.LazyConnectionEnabled,
|
||||||
|
a.config.EnableSSHRoot,
|
||||||
|
a.config.EnableSSHSFTP,
|
||||||
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
|
a.config.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
_, err := a.client.Login(ctx, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
|
func (a *Auth) registerPeer(ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) error {
|
||||||
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
|
if err != nil && jwtToken == "" {
|
||||||
|
return fmt.Errorf("%w: invalid setup-key or no SSO information provided: %v", mgm.ErrInvalidArgument, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
|
info := system.GetInfo(ctx)
|
||||||
|
info.SetFlags(
|
||||||
|
a.config.RosenpassEnabled,
|
||||||
|
a.config.RosenpassPermissive,
|
||||||
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.DisableClientRoutes,
|
||||||
|
a.config.DisableServerRoutes,
|
||||||
|
a.config.DisableDNS,
|
||||||
|
a.config.DisableFirewall,
|
||||||
|
a.config.BlockLANAccess,
|
||||||
|
a.config.BlockInbound,
|
||||||
|
a.config.LazyConnectionEnabled,
|
||||||
|
a.config.EnableSSHRoot,
|
||||||
|
a.config.EnableSSHSFTP,
|
||||||
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
|
a.config.DisableSSHAuth,
|
||||||
|
)
|
||||||
|
|
||||||
|
// todo: fix error handling of validSetupKey
|
||||||
|
if err := a.client.Register(ctx, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels); err != nil {
|
||||||
|
log.Errorf("failed registering peer %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("peer has been successfully registered on Management Service")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPermissionDenied checks if the error is a PermissionDenied error
|
||||||
|
func isPermissionDenied(err error) bool {
|
||||||
|
return errors.Is(err, mgm.ErrPermissionDenied)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isLoginNeeded checks if the error indicates login is required
|
||||||
|
func isLoginNeeded(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return errors.Is(err, mgm.ErrInvalidArgument) ||
|
||||||
|
errors.Is(err, mgm.ErrPermissionDenied) ||
|
||||||
|
errors.Is(err, mgm.ErrUnauthenticated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isRegistrationNeeded checks if the error indicates peer registration is needed
|
||||||
|
func isRegistrationNeeded(err error) bool {
|
||||||
|
return isPermissionDenied(err)
|
||||||
|
}
|
||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,12 +25,56 @@ const (
|
|||||||
|
|
||||||
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
var _ OAuthFlow = &DeviceAuthorizationFlow{}
|
||||||
|
|
||||||
|
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
||||||
|
type DeviceAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Domain An IDP API domain
|
||||||
|
// Deprecated. Use OIDCConfigEndpoint instead
|
||||||
|
Domain string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
||||||
|
DeviceAuthEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateDeviceAuthConfig validates device authorization provider configuration
|
||||||
|
func validateDeviceAuthConfig(config *DeviceAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.Audience == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Audience")
|
||||||
|
}
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.DeviceAuthEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Device Auth Scopes")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
// DeviceAuthorizationFlow implements the OAuthFlow interface,
|
||||||
// for the Device Authorization Flow.
|
// for the Device Authorization Flow.
|
||||||
type DeviceAuthorizationFlow struct {
|
type DeviceAuthorizationFlow struct {
|
||||||
providerConfig internal.DeviceAuthProviderConfig
|
providerConfig DeviceAuthProviderConfig
|
||||||
|
HTTPClient HTTPClient
|
||||||
HTTPClient HTTPClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestDeviceCodePayload used for request device code payload for auth0
|
// RequestDeviceCodePayload used for request device code payload for auth0
|
||||||
@@ -57,7 +100,7 @@ type TokenRequestResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewDeviceAuthorizationFlow returns device authorization flow client
|
// NewDeviceAuthorizationFlow returns device authorization flow client
|
||||||
func NewDeviceAuthorizationFlow(config internal.DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
func NewDeviceAuthorizationFlow(config DeviceAuthProviderConfig) (*DeviceAuthorizationFlow, error) {
|
||||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
httpTransport.MaxIdleConns = 5
|
httpTransport.MaxIdleConns = 5
|
||||||
|
|
||||||
@@ -89,6 +132,11 @@ func (d *DeviceAuthorizationFlow) GetClientID(ctx context.Context) string {
|
|||||||
return d.providerConfig.ClientID
|
return d.providerConfig.ClientID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the device authorization flow
|
||||||
|
func (d *DeviceAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
d.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// RequestAuthInfo requests a device code login flow information from Hosted
|
// RequestAuthInfo requests a device code login flow information from Hosted
|
||||||
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||||
form := url.Values{}
|
form := url.Values{}
|
||||||
@@ -199,14 +247,22 @@ func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestR
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
// WaitToken waits user's login and authorize the app. Once the user's authorize
|
||||||
// it retrieves the access token from Hosted's endpoint and validates it before returning
|
// it retrieves the access token from Hosted's endpoint and validates it before returning.
|
||||||
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
interval := time.Duration(info.Interval) * time.Second
|
interval := time.Duration(info.Interval) * time.Second
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
|
||||||
tokenResponse, err := d.requestToken(info)
|
tokenResponse, err := d.requestToken(info)
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockHTTPClient struct {
|
type mockHTTPClient struct {
|
||||||
@@ -115,18 +113,19 @@ func TestHosted_RequestDeviceCode(t *testing.T) {
|
|||||||
err: testCase.inputReqError,
|
err: testCase.inputReqError,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := &DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: expectedAudience,
|
||||||
Audience: expectedAudience,
|
ClientID: expectedClientID,
|
||||||
ClientID: expectedClientID,
|
Scope: expectedScope,
|
||||||
Scope: expectedScope,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
authInfo, err := deviceFlow.RequestAuthInfo(context.TODO())
|
||||||
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
||||||
|
|
||||||
@@ -280,18 +279,19 @@ func TestHosted_WaitToken(t *testing.T) {
|
|||||||
countResBody: testCase.inputCountResBody,
|
countResBody: testCase.inputCountResBody,
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlow := DeviceAuthorizationFlow{
|
config := DeviceAuthProviderConfig{
|
||||||
providerConfig: internal.DeviceAuthProviderConfig{
|
Audience: testCase.inputAudience,
|
||||||
Audience: testCase.inputAudience,
|
ClientID: clientID,
|
||||||
ClientID: clientID,
|
TokenEndpoint: "test.hosted.com/token",
|
||||||
TokenEndpoint: "test.hosted.com/token",
|
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
||||||
DeviceAuthEndpoint: "test.hosted.com/device/auth",
|
Scope: "openid",
|
||||||
Scope: "openid",
|
UseIDToken: false,
|
||||||
UseIDToken: false,
|
|
||||||
},
|
|
||||||
HTTPClient: &httpClient,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceFlow, err := NewDeviceAuthorizationFlow(config)
|
||||||
|
require.NoError(t, err, "creating device flow should not fail")
|
||||||
|
deviceFlow.HTTPClient = &httpClient
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
tokenInfo, err := deviceFlow.WaitToken(ctx, testCase.inputInfo)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,19 +86,33 @@ func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesk
|
|||||||
|
|
||||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
pkceFlowInfo, err := authClient.getPKCEFlow(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
pkceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
return pkceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
authClient, err := NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
deviceFlowInfo, err := authClient.getDeviceFlow(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch s, ok := gstatus.FromError(err); {
|
switch s, ok := gstatus.FromError(err); {
|
||||||
case ok && s.Code() == codes.NotFound:
|
case ok && s.Code() == codes.NotFound:
|
||||||
@@ -114,7 +127,9 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
if hint != "" {
|
||||||
|
deviceFlowInfo.SetLoginHint(hint)
|
||||||
|
}
|
||||||
|
|
||||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
return deviceFlowInfo, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/templates"
|
"github.com/netbirdio/netbird/client/internal/templates"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
var _ OAuthFlow = &PKCEAuthorizationFlow{}
|
||||||
@@ -33,17 +33,67 @@ const (
|
|||||||
defaultPKCETimeoutSeconds = 300
|
defaultPKCETimeoutSeconds = 300
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PKCEAuthProviderConfig has all attributes needed to initiate PKCE authorization flow
|
||||||
|
type PKCEAuthProviderConfig struct {
|
||||||
|
// ClientID An IDP application client id
|
||||||
|
ClientID string
|
||||||
|
// ClientSecret An IDP application client secret
|
||||||
|
ClientSecret string
|
||||||
|
// Audience An Audience for to authorization validation
|
||||||
|
Audience string
|
||||||
|
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
||||||
|
TokenEndpoint string
|
||||||
|
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
// Scopes provides the scopes to be included in the token request
|
||||||
|
Scope string
|
||||||
|
// RedirectURL handles authorization code from IDP manager
|
||||||
|
RedirectURLs []string
|
||||||
|
// UseIDToken indicates if the id token should be used for authentication
|
||||||
|
UseIDToken bool
|
||||||
|
// ClientCertPair is used for mTLS authentication to the IDP
|
||||||
|
ClientCertPair *tls.Certificate
|
||||||
|
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||||
|
DisablePromptLogin bool
|
||||||
|
// LoginFlag is used to configure the PKCE flow login behavior
|
||||||
|
LoginFlag common.LoginFlag
|
||||||
|
// LoginHint is used to pre-fill the email/username field during authentication
|
||||||
|
LoginHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePKCEConfig validates PKCE provider configuration
|
||||||
|
func validatePKCEConfig(config *PKCEAuthProviderConfig) error {
|
||||||
|
errorMsgFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
||||||
|
|
||||||
|
if config.ClientID == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Client ID")
|
||||||
|
}
|
||||||
|
if config.TokenEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Token Endpoint")
|
||||||
|
}
|
||||||
|
if config.AuthorizationEndpoint == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "Authorization Auth Endpoint")
|
||||||
|
}
|
||||||
|
if config.Scope == "" {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Auth Scopes")
|
||||||
|
}
|
||||||
|
if config.RedirectURLs == nil {
|
||||||
|
return fmt.Errorf(errorMsgFormat, "PKCE Redirect URLs")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
// PKCEAuthorizationFlow implements the OAuthFlow interface for
|
||||||
// the Authorization Code Flow with PKCE.
|
// the Authorization Code Flow with PKCE.
|
||||||
type PKCEAuthorizationFlow struct {
|
type PKCEAuthorizationFlow struct {
|
||||||
providerConfig internal.PKCEAuthProviderConfig
|
providerConfig PKCEAuthProviderConfig
|
||||||
state string
|
state string
|
||||||
codeVerifier string
|
codeVerifier string
|
||||||
oAuthConfig *oauth2.Config
|
oAuthConfig *oauth2.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
// NewPKCEAuthorizationFlow returns new PKCE authorization code flow.
|
||||||
func NewPKCEAuthorizationFlow(config internal.PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
func NewPKCEAuthorizationFlow(config PKCEAuthProviderConfig) (*PKCEAuthorizationFlow, error) {
|
||||||
var availableRedirectURL string
|
var availableRedirectURL string
|
||||||
|
|
||||||
// find the first available redirect URL
|
// find the first available redirect URL
|
||||||
@@ -121,10 +171,21 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLoginHint sets the login hint for the PKCE authorization flow
|
||||||
|
func (p *PKCEAuthorizationFlow) SetLoginHint(hint string) {
|
||||||
|
p.providerConfig.LoginHint = hint
|
||||||
|
}
|
||||||
|
|
||||||
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
// WaitToken waits for the OAuth token in the PKCE Authorization Flow.
|
||||||
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
// It starts an HTTP server to receive the OAuth token callback and waits for the token or an error.
|
||||||
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
// Once the token is received, it is converted to TokenInfo and validated before returning.
|
||||||
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (TokenInfo, error) {
|
// The method creates a timeout context internally based on info.ExpiresIn.
|
||||||
|
func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowInfo) (TokenInfo, error) {
|
||||||
|
// Create timeout context based on flow expiration
|
||||||
|
timeout := time.Duration(info.ExpiresIn) * time.Second
|
||||||
|
waitCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tokenChan := make(chan *oauth2.Token, 1)
|
tokenChan := make(chan *oauth2.Token, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
@@ -135,7 +196,7 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
|
|
||||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||||
defer func() {
|
defer func() {
|
||||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
@@ -146,8 +207,8 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
|||||||
go p.startServer(server, tokenChan, errChan)
|
go p.startServer(server, tokenChan, errChan)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-waitCtx.Done():
|
||||||
return TokenInfo{}, ctx.Err()
|
return TokenInfo{}, waitCtx.Err()
|
||||||
case token := <-tokenChan:
|
case token := <-tokenChan:
|
||||||
return p.parseOAuthToken(token)
|
return p.parseOAuthToken(token)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
mgm "github.com/netbirdio/netbird/shared/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,7 +40,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
config := internal.PKCEAuthProviderConfig{
|
config := PKCEAuthProviderConfig{
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
Audience: "test-audience",
|
Audience: "test-audience",
|
||||||
TokenEndpoint: "https://test-token-endpoint.com/token",
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
|||||||
@@ -162,6 +162,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create management client once outside retry loop
|
||||||
|
mgmClient := mgm.NewClient(c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
||||||
|
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||||
|
mgmClient.SetConnStateListener(mgmNotifier)
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
@@ -180,12 +185,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host)
|
||||||
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
|
if err := mgmClient.ConnectWithoutRetry(engineCtx); err != nil {
|
||||||
if err != nil {
|
return wrapErr(fmt.Errorf("failed connecting to Management Service: %w", err))
|
||||||
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
|
|
||||||
}
|
}
|
||||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
|
||||||
mgmClient.SetConnStateListener(mgmNotifier)
|
|
||||||
|
|
||||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -198,7 +200,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if errors.Is(err, mgm.ErrPermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
_ = c.Stop()
|
_ = c.Stop()
|
||||||
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
return backoff.Permanent(wrapErr(err)) // unrecoverable error
|
||||||
@@ -320,7 +322,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
err = backoff.Retry(operation, backOff)
|
err = backoff.Retry(operation, backOff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if errors.Is(err, mgm.ErrPermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
_ = c.Stop()
|
_ = c.Stop()
|
||||||
}
|
}
|
||||||
@@ -504,12 +506,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
|
|
||||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
sysInfo.SetFlags(
|
sysInfo.SetFlags(
|
||||||
config.RosenpassEnabled,
|
config.RosenpassEnabled,
|
||||||
@@ -528,7 +524,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
config.DisableSSHAuth,
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(ctx, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,136 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
|
||||||
type DeviceAuthorizationFlow struct {
|
|
||||||
Provider string
|
|
||||||
ProviderConfig DeviceAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeviceAuthProviderConfig has all attributes needed to initiate a device authorization flow
|
|
||||||
type DeviceAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Domain An IDP API domain
|
|
||||||
// Deprecated. Use OIDCConfigEndpoint instead
|
|
||||||
Domain string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// DeviceAuthEndpoint is the endpoint of an IDP manager where clients can obtain device authorization code
|
|
||||||
DeviceAuthEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
|
||||||
func GetDeviceAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (DeviceAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoDeviceAuthorizationFlow, err := mgmClient.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve device flow: %v", err)
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceAuthorizationFlow := DeviceAuthorizationFlow{
|
|
||||||
Provider: protoDeviceAuthorizationFlow.Provider.String(),
|
|
||||||
|
|
||||||
ProviderConfig: DeviceAuthProviderConfig{
|
|
||||||
Audience: protoDeviceAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoDeviceAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
Domain: protoDeviceAuthorizationFlow.GetProviderConfig().Domain,
|
|
||||||
TokenEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
DeviceAuthEndpoint: protoDeviceAuthorizationFlow.GetProviderConfig().GetDeviceAuthEndpoint(),
|
|
||||||
Scope: protoDeviceAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
UseIDToken: protoDeviceAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep compatibility with older management versions
|
|
||||||
if deviceAuthorizationFlow.ProviderConfig.Scope == "" {
|
|
||||||
deviceAuthorizationFlow.ProviderConfig.Scope = "openid"
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isDeviceAuthProviderConfigValid(deviceAuthorizationFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return DeviceAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return deviceAuthorizationFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDeviceAuthProviderConfigValid(config DeviceAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.Audience == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Audience")
|
|
||||||
}
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.DeviceAuthEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Device Auth Scopes")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -891,7 +891,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(e.ctx, info); err != nil {
|
||||||
log.Errorf("could not sync meta: error %s", err)
|
log.Errorf("could not sync meta: error %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1517,7 +1517,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(e.ctx, info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, false, err
|
return nil, nil, false, err
|
||||||
}
|
}
|
||||||
@@ -1666,7 +1666,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
|||||||
signalHealthy := e.signal.IsHealthy()
|
signalHealthy := e.signal.IsHealthy()
|
||||||
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
log.Debugf("signal health check: healthy=%t", signalHealthy)
|
||||||
|
|
||||||
managementHealthy := e.mgmClient.IsHealthy()
|
managementHealthy := e.mgmClient.IsHealthy(e.ctx)
|
||||||
log.Debugf("management health check: healthy=%t", managementHealthy)
|
log.Debugf("management health check: healthy=%t", managementHealthy)
|
||||||
|
|
||||||
stuns := slices.Clone(e.STUNs)
|
stuns := slices.Clone(e.STUNs)
|
||||||
|
|||||||
@@ -1503,8 +1503,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
|
mgmtClient := mgmt.NewClient(mgmtAddr, key, false)
|
||||||
if err != nil {
|
if err := mgmtClient.Connect(ctx); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
|
||||||
@@ -1512,13 +1512,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
err = mgmtClient.Register(ctx, setupKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1531,9 +1526,10 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
wgPort := 33100 + i
|
wgPort := 33100 + i
|
||||||
|
testAddr := fmt.Sprintf("100.64.0.%d/24", i+1)
|
||||||
conf := &EngineConfig{
|
conf := &EngineConfig{
|
||||||
WgIfaceName: ifaceName,
|
WgIfaceName: ifaceName,
|
||||||
WgAddr: resp.PeerConfig.Address,
|
WgAddr: testAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: wgPort,
|
WgPort: wgPort,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
|
|||||||
@@ -1,201 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsLoginRequired check that the server is support SSO or not
|
|
||||||
func IsLoginRequired(ctx context.Context, config *profilemanager.Config) (bool, error) {
|
|
||||||
mgmURL := config.ManagementURL
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if isLoginNeeded(err) {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login or register the client
|
|
||||||
func Login(ctx context.Context, config *profilemanager.Config, setupKey string, jwtToken string) error {
|
|
||||||
mgmClient, err := getMgmClient(ctx, config.PrivateKey, config.ManagementURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
cStatus, ok := status.FromError(err)
|
|
||||||
if !ok || ok && cStatus.Code() != codes.Canceled {
|
|
||||||
log.Warnf("failed to close the Management service client, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Debugf("connected to the Management service %s", config.ManagementURL.String())
|
|
||||||
|
|
||||||
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverKey, _, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
|
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
|
||||||
log.Debugf("peer registration required")
|
|
||||||
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm.GrpcClient, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTlsEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTlsEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to the Management service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTlsEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to the Management service %s %v", mgmURL.String(), err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return mgmClient, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *profilemanager.Config) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
|
||||||
sysInfo.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
|
||||||
return serverKey, loginResp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
|
||||||
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
|
||||||
if err != nil && jwtToken == "" {
|
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("sending peer registration request to Management Service")
|
|
||||||
info := system.GetInfo(ctx)
|
|
||||||
info.SetFlags(
|
|
||||||
config.RosenpassEnabled,
|
|
||||||
config.RosenpassPermissive,
|
|
||||||
config.ServerSSHAllowed,
|
|
||||||
config.DisableClientRoutes,
|
|
||||||
config.DisableServerRoutes,
|
|
||||||
config.DisableDNS,
|
|
||||||
config.DisableFirewall,
|
|
||||||
config.BlockLANAccess,
|
|
||||||
config.BlockInbound,
|
|
||||||
config.LazyConnectionEnabled,
|
|
||||||
config.EnableSSHRoot,
|
|
||||||
config.EnableSSHSFTP,
|
|
||||||
config.EnableSSHLocalPortForwarding,
|
|
||||||
config.EnableSSHRemotePortForwarding,
|
|
||||||
config.DisableSSHAuth,
|
|
||||||
)
|
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed registering peer %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("peer has been successfully registered on Management Service")
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isLoginNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func isRegistrationNeeded(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s, ok := status.FromError(err)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if s.Code() == codes.PermissionDenied {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/client/common"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PKCEAuthorizationFlow represents PKCE Authorization Flow information
|
|
||||||
type PKCEAuthorizationFlow struct {
|
|
||||||
ProviderConfig PKCEAuthProviderConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// PKCEAuthProviderConfig has all attributes needed to initiate pkce authorization flow
|
|
||||||
type PKCEAuthProviderConfig struct {
|
|
||||||
// ClientID An IDP application client id
|
|
||||||
ClientID string
|
|
||||||
// ClientSecret An IDP application client secret
|
|
||||||
ClientSecret string
|
|
||||||
// Audience An Audience for to authorization validation
|
|
||||||
Audience string
|
|
||||||
// TokenEndpoint is the endpoint of an IDP manager where clients can obtain access token
|
|
||||||
TokenEndpoint string
|
|
||||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code
|
|
||||||
AuthorizationEndpoint string
|
|
||||||
// Scopes provides the scopes to be included in the token request
|
|
||||||
Scope string
|
|
||||||
// RedirectURL handles authorization code from IDP manager
|
|
||||||
RedirectURLs []string
|
|
||||||
// UseIDToken indicates if the id token should be used for authentication
|
|
||||||
UseIDToken bool
|
|
||||||
// ClientCertPair is used for mTLS authentication to the IDP
|
|
||||||
ClientCertPair *tls.Certificate
|
|
||||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
|
||||||
DisablePromptLogin bool
|
|
||||||
// LoginFlag is used to configure the PKCE flow login behavior
|
|
||||||
LoginFlag common.LoginFlag
|
|
||||||
// LoginHint is used to pre-fill the email/username field during authentication
|
|
||||||
LoginHint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
|
||||||
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
|
|
||||||
// validate our peer's Wireguard PRIVATE key
|
|
||||||
myPrivateKey, err := wgtypes.ParseKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed parsing Wireguard key %s: [%s]", privateKey, err.Error())
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var mgmTLSEnabled bool
|
|
||||||
if mgmURL.Scheme == "https" {
|
|
||||||
mgmTLSEnabled = true
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("connecting to Management Service %s", mgmURL.String())
|
|
||||||
mgmClient, err := mgm.NewClient(ctx, mgmURL.Host, myPrivateKey, mgmTLSEnabled)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed connecting to Management Service %s %v", mgmURL.String(), err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Debugf("connected to the Management service %s", mgmURL.String())
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = mgmClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to close the Management service client %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
serverKey, err := mgmClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoPKCEAuthorizationFlow, err := mgmClient.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
log.Errorf("failed to retrieve pkce flow: %v", err)
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
authFlow := PKCEAuthorizationFlow{
|
|
||||||
ProviderConfig: PKCEAuthProviderConfig{
|
|
||||||
Audience: protoPKCEAuthorizationFlow.GetProviderConfig().GetAudience(),
|
|
||||||
ClientID: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientID(),
|
|
||||||
ClientSecret: protoPKCEAuthorizationFlow.GetProviderConfig().GetClientSecret(),
|
|
||||||
TokenEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetTokenEndpoint(),
|
|
||||||
AuthorizationEndpoint: protoPKCEAuthorizationFlow.GetProviderConfig().GetAuthorizationEndpoint(),
|
|
||||||
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
|
|
||||||
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
|
|
||||||
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
|
|
||||||
ClientCertPair: clientCert,
|
|
||||||
DisablePromptLogin: protoPKCEAuthorizationFlow.GetProviderConfig().GetDisablePromptLogin(),
|
|
||||||
LoginFlag: common.LoginFlag(protoPKCEAuthorizationFlow.GetProviderConfig().GetLoginFlag()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = isPKCEProviderConfigValid(authFlow.ProviderConfig)
|
|
||||||
if err != nil {
|
|
||||||
return PKCEAuthorizationFlow{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return authFlow, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isPKCEProviderConfigValid(config PKCEAuthProviderConfig) error {
|
|
||||||
errorMSGFormat := "invalid provider configuration received from management: %s value is empty. Contact your NetBird administrator"
|
|
||||||
if config.ClientID == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Client ID")
|
|
||||||
}
|
|
||||||
if config.TokenEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Token Endpoint")
|
|
||||||
}
|
|
||||||
if config.AuthorizationEndpoint == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "Authorization Auth Endpoint")
|
|
||||||
}
|
|
||||||
if config.Scope == "" {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Auth Scopes")
|
|
||||||
}
|
|
||||||
if config.RedirectURLs == nil {
|
|
||||||
return fmt.Errorf(errorMSGFormat, "PKCE Redirect URLs")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -730,8 +730,8 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
|||||||
return config, err
|
return config, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := mgm.NewClient(ctx, newURL.Host, key, mgmTlsEnabled)
|
client := mgm.NewClient(newURL.Host, key, mgmTlsEnabled)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
return config, err
|
return config, err
|
||||||
}
|
}
|
||||||
@@ -743,8 +743,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// gRPC check
|
// gRPC check
|
||||||
_, err = client.GetServerPublicKey()
|
if err := client.HealthCheck(ctx); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -208,7 +208,13 @@ func (c *Client) IsLoginRequired() bool {
|
|||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
|
|
||||||
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return true // Assume login is required if we can't create auth client
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
needsLogin, _ := authClient.IsLoginRequired(ctx)
|
||||||
return needsLogin
|
return needsLogin
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,15 +246,17 @@ func (c *Client) LoginForMobile() string {
|
|||||||
|
|
||||||
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
// This could cause a potential race condition with loading the extension which need to be handled on swift side
|
||||||
go func() {
|
go func() {
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
tokenInfo, err := oAuthFlow.WaitToken(ctx, flowInfo)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
jwtToken := tokenInfo.GetTokenToUse()
|
jwtToken := tokenInfo.GetTokenToUse()
|
||||||
_ = internal.Login(ctx, cfg, "", jwtToken)
|
authClient, err := auth.NewAuth(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
_ = authClient.Login(ctx, "", jwtToken)
|
||||||
c.loginComplete = true
|
c.loginComplete = true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
@@ -3,15 +3,8 @@ package NetBirdSDK
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
gstatus "google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/cmd"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
@@ -71,30 +64,21 @@ func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth
|
|||||||
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
// If it returns a flow info than save the configuration and return true. If it gets a codes.NotFound, it means that SSO
|
||||||
// is not supported and returns false without saving the configuration. For other errors return false.
|
// is not supported and returns false without saving the configuration. For other errors return false.
|
||||||
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
||||||
supportsSSO := true
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
if err != nil {
|
||||||
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
|
return false, fmt.Errorf("failed to create auth client: %v", err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
}
|
||||||
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
|
defer authClient.Close()
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
|
|
||||||
supportsSSO = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
|
||||||
}
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check SSO support: %v", err)
|
||||||
return err
|
}
|
||||||
})
|
|
||||||
|
|
||||||
if !supportsSSO {
|
if !supportsSSO {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("backoff cycle failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
@@ -103,32 +87,31 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
|
|||||||
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
func (a *Auth) LoginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
|
||||||
//nolint
|
//nolint
|
||||||
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)
|
||||||
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
err := a.withBackOff(a.ctx, func() error {
|
|
||||||
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
|
|
||||||
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
|
|
||||||
// we got an answer from management, exit backoff earlier
|
|
||||||
return backoff.Permanent(backoffErr)
|
|
||||||
}
|
|
||||||
return backoffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
err = authClient.Login(ctxWithValues, setupKey, "")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
return profilemanager.WriteOutConfig(a.cfgPath, a.config)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) Login() error {
|
func (a *Auth) Login() error {
|
||||||
var needsLogin bool
|
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create auth client: %v", err)
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
needsLogin, err := authClient.IsLoginRequired(a.ctx)
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
|
||||||
return
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("failed to check login requirement: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
@@ -136,25 +119,10 @@ func (a *Auth) Login() error {
|
|||||||
return fmt.Errorf("Not authenticated")
|
return fmt.Errorf("Not authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = authClient.Login(a.ctx, "", jwtToken)
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("backoff cycle failed: %v", err)
|
return fmt.Errorf("login failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
|
|
||||||
return backoff.RetryNotify(
|
|
||||||
bf,
|
|
||||||
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
|
|
||||||
func(err error, duration time.Duration) {
|
|
||||||
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -278,14 +278,22 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
|||||||
|
|
||||||
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
|
||||||
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) {
|
||||||
var status internal.StatusType
|
authClient, err := auth.NewAuth(ctx, s.config.PrivateKey, s.config.ManagementURL, s.config)
|
||||||
err := internal.Login(ctx, s.config, setupKey, jwtToken)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
log.Errorf("failed to create auth client: %v", err)
|
||||||
log.Warnf("failed login: %v", err)
|
return internal.StatusLoginFailed, err
|
||||||
|
}
|
||||||
|
defer authClient.Close()
|
||||||
|
|
||||||
|
var status internal.StatusType
|
||||||
|
err = authClient.Login(ctx, setupKey, jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
// Check if it's an authentication error (permission denied, invalid credentials, unauthenticated)
|
||||||
|
if errors.Is(err, mgm.ErrPermissionDenied) || errors.Is(err, mgm.ErrInvalidArgument) || errors.Is(err, mgm.ErrUnauthenticated) {
|
||||||
|
log.Warnf("authentication failed: %v", err)
|
||||||
status = internal.StatusNeedsLogin
|
status = internal.StatusNeedsLogin
|
||||||
} else {
|
} else {
|
||||||
log.Errorf("failed login: %v", err)
|
log.Errorf("login failed: %v", err)
|
||||||
status = internal.StatusLoginFailed
|
status = internal.StatusLoginFailed
|
||||||
}
|
}
|
||||||
return status, err
|
return status, err
|
||||||
@@ -606,8 +614,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
|||||||
s.oauthAuthFlow.waitCancel()
|
s.oauthAuthFlow.waitCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
waitTimeout := time.Until(s.oauthAuthFlow.expiresAt)
|
waitCTX, cancel := context.WithCancel(ctx)
|
||||||
waitCTX, cancel := context.WithTimeout(ctx, waitTimeout)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
@@ -1002,8 +1009,8 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
|
|||||||
}
|
}
|
||||||
|
|
||||||
mgmTlsEnabled := config.ManagementURL.Scheme == "https"
|
mgmTlsEnabled := config.ManagementURL.Scheme == "https"
|
||||||
mgmClient, err := mgm.NewClient(ctx, config.ManagementURL.Host, key, mgmTlsEnabled)
|
mgmClient := mgm.NewClient(config.ManagementURL.Host, key, mgmTlsEnabled)
|
||||||
if err != nil {
|
if err := mgmClient.Connect(ctx); err != nil {
|
||||||
return fmt.Errorf("connect to management server: %w", err)
|
return fmt.Errorf("connect to management server: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1012,7 +1019,7 @@ func (s *Server) sendLogoutRequestWithConfig(ctx context.Context, config *profil
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return mgmClient.Logout()
|
return mgmClient.Logout(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status returns the daemon status
|
// Status returns the daemon status
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -14,13 +12,13 @@ import (
|
|||||||
type Client interface {
|
type Client interface {
|
||||||
io.Closer
|
io.Closer
|
||||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKey() (*wgtypes.Key, error)
|
Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) error
|
||||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
Login(ctx context.Context, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error)
|
||||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error)
|
||||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
IsHealthy(ctx context.Context) bool
|
||||||
IsHealthy() bool
|
HealthCheck(ctx context.Context) error
|
||||||
SyncMeta(sysInfo *system.Info) error
|
SyncMeta(ctx context.Context, sysInfo *system.Info) error
|
||||||
Logout() error
|
Logout(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -193,12 +193,12 @@ func TestClient_GetServerPublicKey(t *testing.T) {
|
|||||||
s, listener := startManagement(t)
|
s, listener := startManagement(t)
|
||||||
defer closeManagementSilently(s, listener)
|
defer closeManagementSilently(s, listener)
|
||||||
|
|
||||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
client := NewClient(listener.Addr().String(), testKey, false)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := client.GetServerPublicKey()
|
key, err := client.getServerPublicKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("couldn't retrieve management public key")
|
t.Error("couldn't retrieve management public key")
|
||||||
}
|
}
|
||||||
@@ -216,16 +216,12 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
|||||||
s, listener := startManagement(t)
|
s, listener := startManagement(t)
|
||||||
defer closeManagementSilently(s, listener)
|
defer closeManagementSilently(s, listener)
|
||||||
|
|
||||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
client := NewClient(listener.Addr().String(), testKey, false)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
key, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
sysInfo := system.GetInfo(context.TODO())
|
sysInfo := system.GetInfo(context.TODO())
|
||||||
_, err = client.Login(*key, sysInfo, nil, nil)
|
_, err = client.Login(ctx, sysInfo, nil, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expecting err on unregistered login, got nil")
|
t.Error("expecting err on unregistered login, got nil")
|
||||||
}
|
}
|
||||||
@@ -243,24 +239,16 @@ func TestClient_LoginRegistered(t *testing.T) {
|
|||||||
s, listener := startManagement(t)
|
s, listener := startManagement(t)
|
||||||
defer closeManagementSilently(s, listener)
|
defer closeManagementSilently(s, listener)
|
||||||
|
|
||||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
client := NewClient(listener.Addr().String(), testKey, false)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO())
|
||||||
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
|
err = client.Register(ctx, ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp == nil {
|
|
||||||
t.Error("expecting non nil response, got nil")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_Sync(t *testing.T) {
|
func TestClient_Sync(t *testing.T) {
|
||||||
@@ -272,18 +260,13 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
s, listener := startManagement(t)
|
s, listener := startManagement(t)
|
||||||
defer closeManagementSilently(s, listener)
|
defer closeManagementSilently(s, listener)
|
||||||
|
|
||||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
client := NewClient(listener.Addr().String(), testKey, false)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO())
|
||||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
|
err = client.Register(ctx, ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -293,13 +276,14 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
|
remoteClient := NewClient(listener.Addr().String(), remoteKey, false)
|
||||||
if err != nil {
|
remoteCtx := context.TODO()
|
||||||
|
if err := remoteClient.Connect(remoteCtx); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
info = system.GetInfo(context.TODO())
|
info = system.GetInfo(context.TODO())
|
||||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
|
err = remoteClient.Register(remoteCtx, ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -354,14 +338,9 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
|||||||
serverAddr := lis.Addr().String()
|
serverAddr := lis.Addr().String()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
testClient, err := NewClient(ctx, serverAddr, testKey, false)
|
testClient := NewClient(serverAddr, testKey, false)
|
||||||
if err != nil {
|
if err := testClient.Connect(ctx); err != nil {
|
||||||
t.Fatalf("error while creating testClient: %v", err)
|
t.Fatalf("error while connecting testClient: %v", err)
|
||||||
}
|
|
||||||
|
|
||||||
key, err := testClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error while getting server public key from testclient, %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var actualMeta *mgmtProto.PeerSystemMeta
|
var actualMeta *mgmtProto.PeerSystemMeta
|
||||||
@@ -400,7 +379,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO())
|
||||||
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
|
err = testClient.Register(ctx, ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error while trying to register client: %v", err)
|
t.Errorf("error while trying to register client: %v", err)
|
||||||
}
|
}
|
||||||
@@ -489,9 +468,9 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
serverAddr := lis.Addr().String()
|
serverAddr := lis.Addr().String()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
client, err := NewClient(ctx, serverAddr, testKey, false)
|
client := NewClient(serverAddr, testKey, false)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
t.Fatalf("error while creating testClient: %v", err)
|
t.Fatalf("error while connecting testClient: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
|
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
|
||||||
@@ -512,7 +491,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
|
flowInfo, err := client.GetDeviceAuthorizationFlow(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("error while retrieving device auth flow information")
|
t.Error("error while retrieving device auth flow information")
|
||||||
}
|
}
|
||||||
@@ -533,9 +512,9 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
|||||||
serverAddr := lis.Addr().String()
|
serverAddr := lis.Addr().String()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
client, err := NewClient(ctx, serverAddr, testKey, false)
|
client := NewClient(serverAddr, testKey, false)
|
||||||
if err != nil {
|
if err := client.Connect(ctx); err != nil {
|
||||||
t.Fatalf("error while creating testClient: %v", err)
|
t.Fatalf("error while connecting testClient: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
|
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
|
||||||
@@ -558,7 +537,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
|
flowInfo, err := client.GetPKCEAuthorizationFlow(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("error while retrieving pkce auth flow information")
|
t.Error("error while retrieving pkce auth flow information")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,24 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/wsproxy"
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Custom management client errors that abstract away gRPC error codes
|
||||||
|
var (
|
||||||
|
// ErrPermissionDenied is returned when the server denies access to a resource
|
||||||
|
ErrPermissionDenied = errors.New("permission denied")
|
||||||
|
|
||||||
|
// ErrInvalidArgument is returned when the request contains invalid arguments
|
||||||
|
ErrInvalidArgument = errors.New("invalid argument")
|
||||||
|
|
||||||
|
// ErrUnauthenticated is returned when authentication is required
|
||||||
|
ErrUnauthenticated = errors.New("unauthenticated")
|
||||||
|
|
||||||
|
// ErrNotFound is returned when the requested resource is not found
|
||||||
|
ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
// ErrUnimplemented is returned when the operation is not implemented
|
||||||
|
ErrUnimplemented = errors.New("not implemented")
|
||||||
|
)
|
||||||
|
|
||||||
const ConnectTimeout = 10 * time.Second
|
const ConnectTimeout = 10 * time.Second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -41,40 +59,67 @@ type ConnStateNotifier interface {
|
|||||||
type GrpcClient struct {
|
type GrpcClient struct {
|
||||||
key wgtypes.Key
|
key wgtypes.Key
|
||||||
realClient proto.ManagementServiceClient
|
realClient proto.ManagementServiceClient
|
||||||
ctx context.Context
|
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
connStateCallback ConnStateNotifier
|
connStateCallback ConnStateNotifier
|
||||||
connStateCallbackLock sync.RWMutex
|
connStateCallbackLock sync.RWMutex
|
||||||
|
addr string
|
||||||
|
tlsEnabled bool
|
||||||
|
reconnectMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new client to Management service
|
// NewClient creates a new client to Management service
|
||||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
// The client is not connected after creation - call Connect to establish the connection
|
||||||
|
func NewClient(addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) *GrpcClient {
|
||||||
|
return &GrpcClient{
|
||||||
|
key: ourPrivateKey,
|
||||||
|
addr: addr,
|
||||||
|
tlsEnabled: tlsEnabled,
|
||||||
|
connStateCallbackLock: sync.RWMutex{},
|
||||||
|
reconnectMutex: sync.Mutex{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes a connection to the Management Service with retry logic
|
||||||
|
// Retries connection attempts with exponential backoff on failure
|
||||||
|
func (c *GrpcClient) Connect(ctx context.Context) error {
|
||||||
var conn *grpc.ClientConn
|
var conn *grpc.ClientConn
|
||||||
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
var err error
|
var err error
|
||||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
conn, err = nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create connection: %w", err)
|
log.Warnf("failed to connect to Management Service: %v", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
|
if err := backoff.Retry(operation, defaultBackoff(ctx)); err != nil {
|
||||||
if err != nil {
|
log.Errorf("failed creating connection to Management Service after retries: %v", err)
|
||||||
log.Errorf("failed creating connection to Management Service: %v", err)
|
return fmt.Errorf("create connection: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
realClient := proto.NewManagementServiceClient(conn)
|
c.conn = conn
|
||||||
|
c.realClient = proto.NewManagementServiceClient(conn)
|
||||||
|
|
||||||
return &GrpcClient{
|
log.Infof("connected to the Management Service at %s", c.addr)
|
||||||
key: ourPrivateKey,
|
return nil
|
||||||
realClient: realClient,
|
}
|
||||||
ctx: ctx,
|
|
||||||
conn: conn,
|
// ConnectWithoutRetry establishes a connection to the Management Service without retry logic
|
||||||
connStateCallbackLock: sync.RWMutex{},
|
// Performs a single connection attempt - callers should implement their own retry logic if needed
|
||||||
}, nil
|
func (c *GrpcClient) ConnectWithoutRetry(ctx context.Context) error {
|
||||||
|
conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to connect to Management Service: %v", err)
|
||||||
|
return fmt.Errorf("create connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn = conn
|
||||||
|
c.realClient = proto.NewManagementServiceClient(conn)
|
||||||
|
|
||||||
|
log.Debugf("connected to the Management Service at %s", c.addr)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes connection to the Management Service
|
// Close closes connection to the Management Service
|
||||||
@@ -89,19 +134,6 @@ func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
|
|||||||
c.connStateCallback = notifier
|
c.connStateCallback = notifier
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultBackoff is a basic backoff mechanism for general issues
|
|
||||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
|
||||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
|
||||||
InitialInterval: 800 * time.Millisecond,
|
|
||||||
RandomizationFactor: 1,
|
|
||||||
Multiplier: 1.7,
|
|
||||||
MaxInterval: 10 * time.Second,
|
|
||||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
|
||||||
Stop: backoff.Stop,
|
|
||||||
Clock: backoff.SystemClock,
|
|
||||||
}, ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ready indicates whether the client is okay and ready to be used
|
// ready indicates whether the client is okay and ready to be used
|
||||||
// for now it just checks whether gRPC connection to the service is ready
|
// for now it just checks whether gRPC connection to the service is ready
|
||||||
func (c *GrpcClient) ready() bool {
|
func (c *GrpcClient) ready() bool {
|
||||||
@@ -122,7 +154,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
|||||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf(errMsgMgmtPublicKey, err)
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
return err
|
return err
|
||||||
@@ -177,15 +209,13 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkMap return with the network map
|
// GetNetworkMap return with the network map
|
||||||
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
func (c *GrpcClient) GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error) {
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed getting Management Service public key: %s", err)
|
log.Debugf("failed getting Management Service public key: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
|
||||||
defer cancelStream()
|
|
||||||
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
|
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to open Management Service stream: %s", err)
|
log.Debugf("failed to open Management Service stream: %s", err)
|
||||||
@@ -264,30 +294,32 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
// getServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
||||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
// This is a simple operation without retry logic - callers should handle retries at the operation level
|
||||||
|
func (c *GrpcClient) getServerPublicKey(ctx context.Context) (*wgtypes.Key, error) {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return nil, errors.New(errMsgNoMgmtConnection)
|
return nil, errors.New(errMsgNoMgmtConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||||
return nil, fmt.Errorf("failed while getting Management Service public key")
|
return nil, fmt.Errorf("failed while getting Management Service public key")
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
key, err := wgtypes.ParseKey(resp.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &serverKey, nil
|
return &key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsHealthy probes the gRPC connection and returns false on errors
|
// IsHealthy probes the gRPC connection and returns false on errors
|
||||||
func (c *GrpcClient) IsHealthy() bool {
|
func (c *GrpcClient) IsHealthy(ctx context.Context) bool {
|
||||||
switch c.conn.GetState() {
|
switch c.conn.GetState() {
|
||||||
case connectivity.TransientFailure:
|
case connectivity.TransientFailure:
|
||||||
return false
|
return false
|
||||||
@@ -299,10 +331,10 @@ func (c *GrpcClient) IsHealthy() bool {
|
|||||||
case connectivity.Ready:
|
case connectivity.Ready:
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
healthCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
_, err := c.realClient.GetServerKey(healthCtx, &proto.Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.notifyDisconnected(err)
|
c.notifyDisconnected(err)
|
||||||
log.Warnf("health check returned: %s", err)
|
log.Warnf("health check returned: %s", err)
|
||||||
@@ -312,12 +344,26 @@ func (c *GrpcClient) IsHealthy() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
// HealthCheck verifies connectivity to the management server
|
||||||
|
// Returns an error if the server is not reachable
|
||||||
|
// Internally uses getServerPublicKey to verify the connection
|
||||||
|
func (c *GrpcClient) HealthCheck(ctx context.Context) error {
|
||||||
|
_, err := c.getServerPublicKey(ctx)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *GrpcClient) login(ctx context.Context, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return nil, errors.New(errMsgNoMgmtConnection)
|
return nil, errors.New(errMsgNoMgmtConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
serverKey, err := c.getServerPublicKey(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to encrypt message: %s", err)
|
log.Errorf("failed to encrypt message: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -325,7 +371,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
|||||||
|
|
||||||
var resp *proto.EncryptedMessage
|
var resp *proto.EncryptedMessage
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
|
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
@@ -344,14 +390,14 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
|
err = backoff.Retry(operation, nbgrpc.Backoff(ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to login to Management Service: %v", err)
|
log.Errorf("failed to login to Management Service: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
loginResp := &proto.LoginResponse{}
|
loginResp := &proto.LoginResponse{}
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to decrypt login response: %s", err)
|
log.Errorf("failed to decrypt login response: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -363,99 +409,135 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
|||||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||||
// Takes care of encrypting and decrypting messages.
|
// Takes care of encrypting and decrypting messages.
|
||||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||||
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
|
||||||
|
func (c *GrpcClient) Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) error {
|
||||||
keys := &proto.PeerKeys{
|
keys := &proto.PeerKeys{
|
||||||
SshPubKey: pubSSHKey,
|
SshPubKey: pubSSHKey,
|
||||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||||
}
|
}
|
||||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
_, err := c.login(ctx, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||||
|
return wrapGRPCError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||||
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
|
||||||
|
func (c *GrpcClient) Login(ctx context.Context, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||||
keys := &proto.PeerKeys{
|
keys := &proto.PeerKeys{
|
||||||
SshPubKey: pubSSHKey,
|
SshPubKey: pubSSHKey,
|
||||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||||
}
|
}
|
||||||
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
resp, err := c.login(ctx, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||||
|
return resp, wrapGRPCError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
||||||
// It also takes care of encrypting and decrypting messages.
|
// It also takes care of encrypting and decrypting messages.
|
||||||
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
// It automatically retries with backoff and reconnection on connection errors.
|
||||||
if !c.ready() {
|
// Returns custom errors: ErrNotFound, ErrUnimplemented
|
||||||
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
|
func (c *GrpcClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) {
|
||||||
}
|
var flowInfoResp *proto.DeviceAuthorizationFlow
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
message := &proto.DeviceAuthorizationFlowRequest{}
|
err := c.withRetry(ctx, func() error {
|
||||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
if !c.ready() {
|
||||||
if err != nil {
|
return fmt.Errorf("no connection to management in order to get device authorization flow")
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
|
serverKey, err := c.getServerPublicKey(ctx)
|
||||||
WgPubKey: c.key.PublicKey().String(),
|
if err != nil {
|
||||||
Body: encryptedMSG},
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
)
|
return err
|
||||||
if err != nil {
|
}
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
flowInfoResp := &proto.DeviceAuthorizationFlow{}
|
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2)
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
defer cancel()
|
||||||
if err != nil {
|
|
||||||
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
|
||||||
log.Error(errWithMSG)
|
|
||||||
return nil, errWithMSG
|
|
||||||
}
|
|
||||||
|
|
||||||
return flowInfoResp, nil
|
message := &proto.DeviceAuthorizationFlowRequest{}
|
||||||
|
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
|
||||||
|
WgPubKey: c.key.PublicKey().String(),
|
||||||
|
Body: encryptedMSG},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
flowInfo := &proto.DeviceAuthorizationFlow{}
|
||||||
|
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo)
|
||||||
|
if err != nil {
|
||||||
|
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
||||||
|
log.Error(errWithMSG)
|
||||||
|
return errWithMSG
|
||||||
|
}
|
||||||
|
|
||||||
|
flowInfoResp = flowInfo
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return flowInfoResp, wrapGRPCError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
||||||
// It also takes care of encrypting and decrypting messages.
|
// It also takes care of encrypting and decrypting messages.
|
||||||
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
// It automatically retries with backoff and reconnection on connection errors.
|
||||||
if !c.ready() {
|
// Returns custom errors: ErrNotFound, ErrUnimplemented
|
||||||
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
func (c *GrpcClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) {
|
||||||
}
|
var flowInfoResp *proto.PKCEAuthorizationFlow
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
message := &proto.PKCEAuthorizationFlowRequest{}
|
err := c.withRetry(ctx, func() error {
|
||||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
if !c.ready() {
|
||||||
if err != nil {
|
return fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
||||||
return nil, err
|
}
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
|
serverKey, err := c.getServerPublicKey(ctx)
|
||||||
WgPubKey: c.key.PublicKey().String(),
|
if err != nil {
|
||||||
Body: encryptedMSG,
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
message := &proto.PKCEAuthorizationFlowRequest{}
|
||||||
|
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
|
||||||
|
WgPubKey: c.key.PublicKey().String(),
|
||||||
|
Body: encryptedMSG,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
flowInfo := &proto.PKCEAuthorizationFlow{}
|
||||||
|
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo)
|
||||||
|
if err != nil {
|
||||||
|
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
||||||
|
log.Error(errWithMSG)
|
||||||
|
return errWithMSG
|
||||||
|
}
|
||||||
|
|
||||||
|
flowInfoResp = flowInfo
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
flowInfoResp := &proto.PKCEAuthorizationFlow{}
|
return flowInfoResp, wrapGRPCError(err)
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
|
||||||
if err != nil {
|
|
||||||
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
|
||||||
log.Error(errWithMSG)
|
|
||||||
return nil, errWithMSG
|
|
||||||
}
|
|
||||||
|
|
||||||
return flowInfoResp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyncMeta sends updated system metadata to the Management Service.
|
// SyncMeta sends updated system metadata to the Management Service.
|
||||||
// It should be used if there is changes on peer posture check after initial sync.
|
// It should be used if there is changes on peer posture check after initial sync.
|
||||||
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
func (c *GrpcClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return errors.New(errMsgNoMgmtConnection)
|
return errors.New(errMsgNoMgmtConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf(errMsgMgmtPublicKey, err)
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
return err
|
return err
|
||||||
@@ -467,7 +549,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
|
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
|
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
|
||||||
@@ -497,13 +579,13 @@ func (c *GrpcClient) notifyConnected() {
|
|||||||
c.connStateCallback.MarkManagementConnected()
|
c.connStateCallback.MarkManagementConnected()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) Logout() error {
|
func (c *GrpcClient) Logout(ctx context.Context) error {
|
||||||
serverKey, err := c.GetServerPublicKey()
|
serverKey, err := c.getServerPublicKey(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get server public key: %w", err)
|
return fmt.Errorf("get server public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*15)
|
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*15)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
message := &proto.Empty{}
|
message := &proto.Empty{}
|
||||||
@@ -523,6 +605,156 @@ func (c *GrpcClient) Logout() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reconnect closes the current connection and creates a new one
|
||||||
|
func (c *GrpcClient) reconnect(ctx context.Context) error {
|
||||||
|
c.reconnectMutex.Lock()
|
||||||
|
defer c.reconnectMutex.Unlock()
|
||||||
|
|
||||||
|
// Close existing connection
|
||||||
|
if c.conn != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Debugf("error closing old connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new connection
|
||||||
|
log.Debugf("reconnecting to Management Service %s", c.addr)
|
||||||
|
conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed reconnecting to Management Service %s: %v", c.addr, err)
|
||||||
|
return fmt.Errorf("reconnect: create connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn = conn
|
||||||
|
c.realClient = proto.NewManagementServiceClient(conn)
|
||||||
|
log.Debugf("successfully reconnected to Management service %s", c.addr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// withRetry wraps an operation with exponential backoff retry logic
|
||||||
|
// It automatically reconnects on connection errors
|
||||||
|
func (c *GrpcClient) withRetry(ctx context.Context, operation func() error) error {
|
||||||
|
backoffSettings := &backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 500 * time.Millisecond,
|
||||||
|
RandomizationFactor: 0.5,
|
||||||
|
Multiplier: 1.5,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 2 * time.Minute,
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}
|
||||||
|
backoffSettings.Reset()
|
||||||
|
|
||||||
|
return backoff.RetryNotify(
|
||||||
|
func() error {
|
||||||
|
err := operation()
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a connection error, attempt reconnection
|
||||||
|
if isConnectionError(err) {
|
||||||
|
log.Warnf("connection error detected, attempting reconnection: %v", err)
|
||||||
|
if reconnectErr := c.reconnect(ctx); reconnectErr != nil {
|
||||||
|
log.Errorf("reconnection failed: %v", reconnectErr)
|
||||||
|
return reconnectErr
|
||||||
|
}
|
||||||
|
// Return the original error to trigger retry with the new connection
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// For authentication errors (InvalidArgument, PermissionDenied), don't retry
|
||||||
|
if isAuthenticationError(err) {
|
||||||
|
return backoff.Permanent(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
backoff.WithContext(backoffSettings, ctx),
|
||||||
|
func(err error, duration time.Duration) {
|
||||||
|
log.Warnf("operation failed, retrying in %v: %v", duration, err)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultBackoff is a basic backoff mechanism for general issues
|
||||||
|
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||||
|
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: 800 * time.Millisecond,
|
||||||
|
RandomizationFactor: 1,
|
||||||
|
Multiplier: 1.7,
|
||||||
|
MaxInterval: 10 * time.Second,
|
||||||
|
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||||
|
Stop: backoff.Stop,
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
|
||||||
|
func isConnectionError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := gstatus.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// These error codes indicate connection issues
|
||||||
|
return s.Code() == codes.Unavailable ||
|
||||||
|
s.Code() == codes.DeadlineExceeded ||
|
||||||
|
s.Code() == codes.Canceled ||
|
||||||
|
s.Code() == codes.Internal
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAuthenticationError checks if the error is an authentication-related error that should not be retried
|
||||||
|
func isAuthenticationError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s, ok := gstatus.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapGRPCError converts gRPC errors to custom management client errors
|
||||||
|
func wrapGRPCError(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's already a custom error
|
||||||
|
if errors.Is(err, ErrPermissionDenied) ||
|
||||||
|
errors.Is(err, ErrInvalidArgument) ||
|
||||||
|
errors.Is(err, ErrUnauthenticated) ||
|
||||||
|
errors.Is(err, ErrNotFound) ||
|
||||||
|
errors.Is(err, ErrUnimplemented) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert gRPC status errors to custom errors
|
||||||
|
s, ok := gstatus.FromError(err)
|
||||||
|
if !ok {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch s.Code() {
|
||||||
|
case codes.PermissionDenied:
|
||||||
|
return fmt.Errorf("%w: %s", ErrPermissionDenied, s.Message())
|
||||||
|
case codes.InvalidArgument:
|
||||||
|
return fmt.Errorf("%w: %s", ErrInvalidArgument, s.Message())
|
||||||
|
case codes.Unauthenticated:
|
||||||
|
return fmt.Errorf("%w: %s", ErrUnauthenticated, s.Message())
|
||||||
|
case codes.NotFound:
|
||||||
|
return fmt.Errorf("%w: %s", ErrNotFound, s.Message())
|
||||||
|
case codes.Unimplemented:
|
||||||
|
return fmt.Errorf("%w: %s", ErrUnimplemented, s.Message())
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||||
if info == nil {
|
if info == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package client
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -13,17 +11,21 @@ import (
|
|||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
CloseFunc func() error
|
CloseFunc func() error
|
||||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
RegisterFunc func(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error
|
||||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
LoginFunc func(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
GetDeviceAuthorizationFlowFunc func(ctx context.Context) (*proto.DeviceAuthorizationFlow, error)
|
||||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
GetPKCEAuthorizationFlowFunc func(ctx context.Context) (*proto.PKCEAuthorizationFlow, error)
|
||||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
SyncMetaFunc func(ctx context.Context, sysInfo *system.Info) error
|
||||||
SyncMetaFunc func(sysInfo *system.Info) error
|
HealthCheckFunc func(ctx context.Context) error
|
||||||
LogoutFunc func() error
|
LogoutFunc func(ctx context.Context) error
|
||||||
|
IsHealthyFunc func(ctx context.Context) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) IsHealthy() bool {
|
func (m *MockClient) IsHealthy(ctx context.Context) bool {
|
||||||
return true
|
if m.IsHealthyFunc == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return m.IsHealthyFunc(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Close() error {
|
func (m *MockClient) Close() error {
|
||||||
@@ -40,56 +42,56 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
|
|||||||
return m.SyncFunc(ctx, sysInfo, msgHandler)
|
return m.SyncFunc(ctx, sysInfo, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
func (m *MockClient) Register(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error {
|
||||||
if m.GetServerPublicKeyFunc == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return m.GetServerPublicKeyFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
|
||||||
if m.RegisterFunc == nil {
|
if m.RegisterFunc == nil {
|
||||||
return nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
|
return m.RegisterFunc(ctx, setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
func (m *MockClient) Login(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||||
if m.LoginFunc == nil {
|
if m.LoginFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
|
return m.LoginFunc(ctx, info, sshKey, dnsLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
func (m *MockClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) {
|
||||||
if m.GetDeviceAuthorizationFlowFunc == nil {
|
if m.GetDeviceAuthorizationFlowFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
return m.GetDeviceAuthorizationFlowFunc(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
func (m *MockClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) {
|
||||||
if m.GetPKCEAuthorizationFlowFunc == nil {
|
if m.GetPKCEAuthorizationFlowFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.GetPKCEAuthorizationFlow(serverKey)
|
return m.GetPKCEAuthorizationFlowFunc(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
|
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
|
||||||
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
|
func (m *MockClient) GetNetworkMap(ctx context.Context, _ *system.Info) (*proto.NetworkMap, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
|
func (m *MockClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error {
|
||||||
if m.SyncMetaFunc == nil {
|
if m.SyncMetaFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.SyncMetaFunc(sysInfo)
|
return m.SyncMetaFunc(ctx, sysInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Logout() error {
|
func (m *MockClient) HealthCheck(ctx context.Context) error {
|
||||||
|
if m.HealthCheckFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.HealthCheckFunc(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) Logout(ctx context.Context) error {
|
||||||
if m.LogoutFunc == nil {
|
if m.LogoutFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.LogoutFunc()
|
return m.LogoutFunc(ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user