Refactor gRPC auth and retry logic

- Centralize retry logic in auth layer
- Decouple gRPC connection logic with new Connect method
- Refactor management client to fetch server public key internally
- Add dedicated HealthCheck method for connection verification
- Simplify getServerPublicKey by removing retry logic
This commit is contained in:
Zoltán Papp
2025-12-24 11:23:51 +01:00
parent 7285fef0f0
commit 4b3e1f1b52
23 changed files with 1008 additions and 932 deletions

View File

@@ -4,8 +4,6 @@ import (
"context"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -14,13 +12,13 @@ import (
type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
Logout() error
Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) error
Login(ctx context.Context, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy(ctx context.Context) bool
HealthCheck(ctx context.Context) error
SyncMeta(ctx context.Context, sysInfo *system.Info) error
Logout(ctx context.Context) error
}

View File

@@ -193,12 +193,12 @@ func TestClient_GetServerPublicKey(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
key, err := client.getServerPublicKey(ctx)
if err != nil {
t.Error("couldn't retrieve management public key")
}
@@ -216,16 +216,12 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil)
_, err = client.Login(ctx, sysInfo, nil, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
@@ -243,24 +239,16 @@ func TestClient_LoginRegistered(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
err = client.Register(ctx, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
if resp == nil {
t.Error("expecting non nil response, got nil")
}
}
func TestClient_Sync(t *testing.T) {
@@ -272,18 +260,13 @@ func TestClient_Sync(t *testing.T) {
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
client := NewClient(listener.Addr().String(), testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
err = client.Register(ctx, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -293,13 +276,14 @@ func TestClient_Sync(t *testing.T) {
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
remoteClient := NewClient(listener.Addr().String(), remoteKey, false)
remoteCtx := context.TODO()
if err := remoteClient.Connect(remoteCtx); err != nil {
t.Fatal(err)
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
err = remoteClient.Register(remoteCtx, ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -354,14 +338,9 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
serverAddr := lis.Addr().String()
ctx := context.Background()
testClient, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
testClient := NewClient(serverAddr, testKey, false)
if err := testClient.Connect(ctx); err != nil {
t.Fatalf("error while connecting testClient: %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
@@ -400,7 +379,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
err = testClient.Register(ctx, ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
@@ -489,9 +468,9 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
client := NewClient(serverAddr, testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatalf("error while connecting testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
@@ -512,7 +491,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
flowInfo, err := client.GetDeviceAuthorizationFlow(ctx)
if err != nil {
t.Error("error while retrieving device auth flow information")
}
@@ -533,9 +512,9 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
client := NewClient(serverAddr, testKey, false)
if err := client.Connect(ctx); err != nil {
t.Fatalf("error while connecting testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
@@ -558,7 +537,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
flowInfo, err := client.GetPKCEAuthorizationFlow(ctx)
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}

View File

@@ -25,6 +25,24 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
// Custom management client errors that abstract away gRPC error codes
var (
// ErrPermissionDenied is returned when the server denies access to a resource
ErrPermissionDenied = errors.New("permission denied")
// ErrInvalidArgument is returned when the request contains invalid arguments
ErrInvalidArgument = errors.New("invalid argument")
// ErrUnauthenticated is returned when authentication is required
ErrUnauthenticated = errors.New("unauthenticated")
// ErrNotFound is returned when the requested resource is not found
ErrNotFound = errors.New("not found")
// ErrUnimplemented is returned when the operation is not implemented
ErrUnimplemented = errors.New("not implemented")
)
const ConnectTimeout = 10 * time.Second
const (
@@ -41,40 +59,67 @@ type ConnStateNotifier interface {
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
addr string
tlsEnabled bool
reconnectMutex sync.Mutex
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
// The client is not connected after creation - call Connect to establish the connection
func NewClient(addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) *GrpcClient {
return &GrpcClient{
key: ourPrivateKey,
addr: addr,
tlsEnabled: tlsEnabled,
connStateCallbackLock: sync.RWMutex{},
reconnectMutex: sync.Mutex{},
}
}
// Connect establishes a connection to the Management Service with retry logic
// Retries connection attempts with exponential backoff on failure
func (c *GrpcClient) Connect(ctx context.Context) error {
var conn *grpc.ClientConn
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
conn, err = nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
return fmt.Errorf("create connection: %w", err)
log.Warnf("failed to connect to Management Service: %v", err)
return err
}
return nil
}
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
if err := backoff.Retry(operation, defaultBackoff(ctx)); err != nil {
log.Errorf("failed creating connection to Management Service after retries: %v", err)
return fmt.Errorf("create connection: %w", err)
}
realClient := proto.NewManagementServiceClient(conn)
c.conn = conn
c.realClient = proto.NewManagementServiceClient(conn)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil
log.Infof("connected to the Management Service at %s", c.addr)
return nil
}
// ConnectWithoutRetry establishes a connection to the Management Service without retry logic
// Performs a single connection attempt - callers should implement their own retry logic if needed
func (c *GrpcClient) ConnectWithoutRetry(ctx context.Context) error {
conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Warnf("failed to connect to Management Service: %v", err)
return fmt.Errorf("create connection: %w", err)
}
c.conn = conn
c.realClient = proto.NewManagementServiceClient(conn)
log.Debugf("connected to the Management Service at %s", c.addr)
return nil
}
// Close closes connection to the Management Service
@@ -89,19 +134,6 @@ func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *GrpcClient) ready() bool {
@@ -122,7 +154,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -177,15 +209,13 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
func (c *GrpcClient) GetNetworkMap(ctx context.Context, sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
@@ -264,30 +294,32 @@ func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, se
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
// getServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
// This is a simple operation without retry logic - callers should handle retries at the operation level
func (c *GrpcClient) getServerPublicKey(ctx context.Context) (*wgtypes.Key, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
mgmCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
}
serverKey, err := wgtypes.ParseKey(resp.Key)
key, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
return &key, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
func (c *GrpcClient) IsHealthy() bool {
func (c *GrpcClient) IsHealthy(ctx context.Context) bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
return false
@@ -299,10 +331,10 @@ func (c *GrpcClient) IsHealthy() bool {
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
healthCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
_, err := c.realClient.GetServerKey(healthCtx, &proto.Empty{})
if err != nil {
c.notifyDisconnected(err)
log.Warnf("health check returned: %s", err)
@@ -312,12 +344,26 @@ func (c *GrpcClient) IsHealthy() bool {
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
// HealthCheck verifies connectivity to the management server
// Returns an error if the server is not reachable
// Internally uses getServerPublicKey to verify the connection
func (c *GrpcClient) HealthCheck(ctx context.Context) error {
_, err := c.getServerPublicKey(ctx)
return err
}
func (c *GrpcClient) login(ctx context.Context, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return nil, err
}
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
@@ -325,7 +371,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
var err error
@@ -344,14 +390,14 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
err = backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
@@ -363,99 +409,135 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
func (c *GrpcClient) Register(ctx context.Context, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) error {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
_, err := c.login(ctx, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return wrapGRPCError(err)
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
// Returns custom errors: ErrPermissionDenied, ErrInvalidArgument, ErrUnauthenticated
func (c *GrpcClient) Login(ctx context.Context, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
resp, err := c.login(ctx, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return resp, wrapGRPCError(err)
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
// It automatically retries with backoff and reconnection on connection errors.
// Returns custom errors: ErrNotFound, ErrUnimplemented
func (c *GrpcClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) {
var flowInfoResp *proto.DeviceAuthorizationFlow
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
err := c.withRetry(ctx, func() error {
if !c.ready() {
return fmt.Errorf("no connection to management in order to get device authorization flow")
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return nil, err
}
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2)
defer cancel()
return flowInfoResp, nil
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return err
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return err
}
flowInfo := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return errWithMSG
}
flowInfoResp = flowInfo
return nil
})
return flowInfoResp, wrapGRPCError(err)
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
// It automatically retries with backoff and reconnection on connection errors.
// Returns custom errors: ErrNotFound, ErrUnimplemented
func (c *GrpcClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) {
var flowInfoResp *proto.PKCEAuthorizationFlow
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
err := c.withRetry(ctx, func() error {
if !c.ready() {
return fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return err
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return err
}
flowInfo := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfo)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return errWithMSG
}
flowInfoResp = flowInfo
return nil
})
if err != nil {
return nil, err
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
return flowInfoResp, wrapGRPCError(err)
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
func (c *GrpcClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error {
if !c.ready() {
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey(ctx)
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -467,7 +549,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
mgmCtx, cancel := context.WithTimeout(ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
@@ -497,13 +579,13 @@ func (c *GrpcClient) notifyConnected() {
c.connStateCallback.MarkManagementConnected()
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
func (c *GrpcClient) Logout(ctx context.Context) error {
serverKey, err := c.getServerPublicKey(ctx)
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*15)
mgmCtx, cancel := context.WithTimeout(ctx, time.Second*15)
defer cancel()
message := &proto.Empty{}
@@ -523,6 +605,156 @@ func (c *GrpcClient) Logout() error {
return nil
}
// reconnect closes the current connection and creates a new one
func (c *GrpcClient) reconnect(ctx context.Context) error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
// Close existing connection
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Debugf("error closing old connection: %v", err)
}
}
// Create new connection
log.Debugf("reconnecting to Management Service %s", c.addr)
conn, err := nbgrpc.CreateConnection(ctx, c.addr, c.tlsEnabled, wsproxy.ManagementComponent)
if err != nil {
log.Errorf("failed reconnecting to Management Service %s: %v", c.addr, err)
return fmt.Errorf("reconnect: create connection: %w", err)
}
c.conn = conn
c.realClient = proto.NewManagementServiceClient(conn)
log.Debugf("successfully reconnected to Management service %s", c.addr)
return nil
}
// withRetry wraps an operation with exponential backoff retry logic
// It automatically reconnects on connection errors
func (c *GrpcClient) withRetry(ctx context.Context, operation func() error) error {
backoffSettings := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.5,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 2 * time.Minute,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
backoffSettings.Reset()
return backoff.RetryNotify(
func() error {
err := operation()
if err == nil {
return nil
}
// If it's a connection error, attempt reconnection
if isConnectionError(err) {
log.Warnf("connection error detected, attempting reconnection: %v", err)
if reconnectErr := c.reconnect(ctx); reconnectErr != nil {
log.Errorf("reconnection failed: %v", reconnectErr)
return reconnectErr
}
// Return the original error to trigger retry with the new connection
return err
}
// For authentication errors (InvalidArgument, PermissionDenied), don't retry
if isAuthenticationError(err) {
return backoff.Permanent(err)
}
return err
},
backoff.WithContext(backoffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("operation failed, retrying in %v: %v", duration, err)
},
)
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// isConnectionError checks if the error is a connection-related error that should trigger reconnection
func isConnectionError(err error) bool {
if err == nil {
return false
}
s, ok := gstatus.FromError(err)
if !ok {
return false
}
// These error codes indicate connection issues
return s.Code() == codes.Unavailable ||
s.Code() == codes.DeadlineExceeded ||
s.Code() == codes.Canceled ||
s.Code() == codes.Internal
}
// isAuthenticationError checks if the error is an authentication-related error that should not be retried
func isAuthenticationError(err error) bool {
if err == nil {
return false
}
s, ok := gstatus.FromError(err)
if !ok {
return false
}
return s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied
}
// wrapGRPCError converts gRPC errors to custom management client errors
func wrapGRPCError(err error) error {
if err == nil {
return nil
}
// Check if it's already a custom error
if errors.Is(err, ErrPermissionDenied) ||
errors.Is(err, ErrInvalidArgument) ||
errors.Is(err, ErrUnauthenticated) ||
errors.Is(err, ErrNotFound) ||
errors.Is(err, ErrUnimplemented) {
return err
}
// Convert gRPC status errors to custom errors
s, ok := gstatus.FromError(err)
if !ok {
return err
}
switch s.Code() {
case codes.PermissionDenied:
return fmt.Errorf("%w: %s", ErrPermissionDenied, s.Message())
case codes.InvalidArgument:
return fmt.Errorf("%w: %s", ErrInvalidArgument, s.Message())
case codes.Unauthenticated:
return fmt.Errorf("%w: %s", ErrUnauthenticated, s.Message())
case codes.NotFound:
return fmt.Errorf("%w: %s", ErrNotFound, s.Message())
case codes.Unimplemented:
return fmt.Errorf("%w: %s", ErrUnimplemented, s.Message())
default:
return err
}
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil

View File

@@ -3,8 +3,6 @@ package client
import (
"context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -13,17 +11,21 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
RegisterFunc func(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error
LoginFunc func(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(ctx context.Context) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(ctx context.Context) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(ctx context.Context, sysInfo *system.Info) error
HealthCheckFunc func(ctx context.Context) error
LogoutFunc func(ctx context.Context) error
IsHealthyFunc func(ctx context.Context) bool
}
func (m *MockClient) IsHealthy() bool {
return true
func (m *MockClient) IsHealthy(ctx context.Context) bool {
if m.IsHealthyFunc == nil {
return true
}
return m.IsHealthyFunc(ctx)
}
func (m *MockClient) Close() error {
@@ -40,56 +42,56 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Register(ctx context.Context, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) error {
if m.RegisterFunc == nil {
return nil, nil
return nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
return m.RegisterFunc(ctx, setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Login(ctx context.Context, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
return m.LoginFunc(ctx, info, sshKey, dnsLabels)
}
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
func (m *MockClient) GetDeviceAuthorizationFlow(ctx context.Context) (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
return m.GetDeviceAuthorizationFlowFunc(ctx)
}
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
func (m *MockClient) GetPKCEAuthorizationFlow(ctx context.Context) (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetPKCEAuthorizationFlow(serverKey)
return m.GetPKCEAuthorizationFlowFunc(ctx)
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
func (m *MockClient) GetNetworkMap(ctx context.Context, _ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
func (m *MockClient) SyncMeta(ctx context.Context, sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
return m.SyncMetaFunc(ctx, sysInfo)
}
func (m *MockClient) Logout() error {
func (m *MockClient) HealthCheck(ctx context.Context) error {
if m.HealthCheckFunc == nil {
return nil
}
return m.HealthCheckFunc(ctx)
}
func (m *MockClient) Logout(ctx context.Context) error {
if m.LogoutFunc == nil {
return nil
}
return m.LogoutFunc()
return m.LogoutFunc(ctx)
}