[client] Unexport GetServerPublicKey, add HealthCheck method (#5735)

* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
This commit is contained in:
Zoltan Papp
2026-04-07 12:18:21 +02:00
committed by GitHub
parent 435203b13b
commit 0efef671d7
8 changed files with 106 additions and 145 deletions

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey) err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) { if isLoginNeeded(err) {
needsLogin = true needsLogin = true
return nil return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error { err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey) err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) { if isRegistrationNeeded(err) {
log.Debugf("peer registration required") log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey) _, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil { if err != nil {
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance // getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) { func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey() protoFlow, err := client.GetPKCEAuthorizationFlow()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err) log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance // getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) { func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey() protoFlow, err := client.GetDeviceAuthorizationFlow()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err) log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
} }
// doMgmLogin performs the actual login operation with the management service // doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) { func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo) a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels) _, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err return err
} }
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line. // Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) { func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey) validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" { if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service") log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx) info := system.GetInfo(ctx)
a.setSystemInfoFlags(info) a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels) loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil { if err != nil {
log.Errorf("failed registering peer %v", err) log.Errorf("failed registering peer %v", err)
return nil, err return nil, err

View File

@@ -617,12 +617,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) { func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags( sysInfo.SetFlags(
config.RosenpassEnabled, config.RosenpassEnabled,
@@ -641,12 +635,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHRemotePortForwarding, config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth, config.DisableSSHAuth,
) )
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
} }
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier { func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
}, EngineServices{ }, EngineServices{
SignalClient: &signal.MockClient{}, SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{}, MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr, RelayManager: relayMgr,
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
}, EngineServices{ }, EngineServices{
SignalClient: &signal.MockClient{}, SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{}, MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr, RelayManager: relayMgr,
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err return nil, err
} }
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx) info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil) resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
} }
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{ e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient, SignalClient: signalClient,
MgmClient: mgmtClient, MgmClient: mgmtClient,
RelayManager: relayMgr, RelayManager: relayMgr,

View File

@@ -777,8 +777,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
}() }()
// gRPC check // gRPC check
_, err = client.GetServerPublicKey() if err = client.HealthCheck(); err != nil {
if err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String()) log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err return nil, err
} }

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"io" "io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
@@ -16,14 +14,18 @@ type Client interface {
io.Closer io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
GetServerPublicKey() (*wgtypes.Key, error) Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
GetServerURL() string GetServerURL() string
// IsHealthy returns the current connection status without blocking.
// Used by the engine to monitor connectivity in the background.
IsHealthy() bool IsHealthy() bool
// HealthCheck actively probes the management server and returns an error if unreachable.
// Used to validate connectivity before committing configuration changes.
HealthCheck() error
SyncMeta(sysInfo *system.Info) error SyncMeta(sysInfo *system.Info) error
Logout() error Logout() error
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)

View File

@@ -189,7 +189,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) {
} }
} }
func TestClient_GetServerPublicKey(t *testing.T) { func TestClient_HealthCheck(t *testing.T) {
testKey, err := wgtypes.GenerateKey() testKey, err := wgtypes.GenerateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -203,12 +203,8 @@ func TestClient_GetServerPublicKey(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
key, err := client.GetServerPublicKey() if err := client.HealthCheck(); err != nil {
if err != nil { t.Errorf("health check failed: %v", err)
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("got an empty management public key")
} }
} }
@@ -225,12 +221,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO()) sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil) _, err = client.Login(sysInfo, nil, nil)
if err == nil { if err == nil {
t.Error("expecting err on unregistered login, got nil") t.Error("expecting err on unregistered login, got nil")
} }
@@ -253,12 +245,8 @@ func TestClient_LoginRegistered(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO()) info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil) resp, err := client.Register(ValidKey, "", info, nil, nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -282,13 +270,8 @@ func TestClient_Sync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO()) info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil) _, err = client.Register(ValidKey, "", info, nil, nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -304,7 +287,7 @@ func TestClient_Sync(t *testing.T) {
} }
info = system.GetInfo(context.TODO()) info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil) _, err = remoteClient.Register(ValidKey, "", info, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -364,11 +347,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
t.Fatalf("error while creating testClient: %v", err) t.Fatalf("error while creating testClient: %v", err)
} }
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta var actualMeta *mgmtProto.PeerSystemMeta
var actualValidKey string var actualValidKey string
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -405,7 +383,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
} }
info := system.GetInfo(context.TODO()) info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil) _, err = testClient.Register(ValidKey, "", info, nil, nil)
if err != nil { if err != nil {
t.Errorf("error while trying to register client: %v", err) t.Errorf("error while trying to register client: %v", err)
} }
@@ -505,7 +483,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
} }
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -517,7 +495,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}, nil }, nil
} }
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey) flowInfo, err := client.GetDeviceAuthorizationFlow()
if err != nil { if err != nil {
t.Error("error while retrieving device auth flow information") t.Error("error while retrieving device auth flow information")
} }
@@ -551,7 +529,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
} }
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) { mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo) encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -563,7 +541,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}, nil }, nil
} }
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey) flowInfo, err := client.GetPKCEAuthorizationFlow()
if err != nil { if err != nil {
t.Error("error while retrieving pkce auth flow information") t.Error("error while retrieving pkce auth flow information")
} }

View File

@@ -202,7 +202,7 @@ func (c *GrpcClient) withMgmtStream(
return fmt.Errorf("connection to management is not ready and in %s state", connState) return fmt.Errorf("connection to management is not ready and in %s state", connState)
} }
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
log.Debugf(errMsgMgmtPublicKey, err) log.Debugf(errMsgMgmtPublicKey, err)
return err return err
@@ -404,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
// GetNetworkMap return with the network map // GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) { func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
log.Debugf("failed getting Management Service public key: %s", err) log.Debugf("failed getting Management Service public key: %s", err)
return nil, err return nil, err
@@ -490,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli
} }
} }
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server) // HealthCheck actively probes the management server and returns an error if unreachable.
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { // Used to validate connectivity before committing configuration changes.
func (c *GrpcClient) HealthCheck() error {
if !c.ready() { if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection) return errors.New(errMsgNoMgmtConnection)
} }
_, err := c.getServerPublicKey()
return err
}
// getServerPublicKey fetches the server's WireGuard public key.
func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) {
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
defer cancel() defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil { if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err) return nil, fmt.Errorf("failed getting Management Service public key: %w", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
} }
serverKey, err := wgtypes.ParseKey(resp.Key) serverKey, err := wgtypes.ParseKey(resp.Key)
@@ -512,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
return &serverKey, nil return &serverKey, nil
} }
// IsHealthy probes the gRPC connection and returns false on errors // IsHealthy returns the current connection status without blocking.
// Used by the engine to monitor connectivity in the background.
func (c *GrpcClient) IsHealthy() bool { func (c *GrpcClient) IsHealthy() bool {
switch c.conn.GetState() { switch c.conn.GetState() {
case connectivity.TransientFailure: case connectivity.TransientFailure:
@@ -538,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool {
return true return true
} }
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() { if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection) return nil, errors.New(errMsgNoMgmtConnection)
} }
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
if err != nil { if err != nil {
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)
return nil, err return nil, err
@@ -577,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
} }
loginResp := &proto.LoginResponse{} loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
if err != nil { if err != nil {
log.Errorf("failed to decrypt login response: %s", err) log.Errorf("failed to decrypt login response: %s", err)
return nil, err return nil, err
@@ -589,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages. // Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc) // This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{ keys := &proto.PeerKeys{
SshPubKey: pubSSHKey, SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()), WgPubKey: []byte(c.key.PublicKey().String()),
} }
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
} }
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages. // Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{ keys := &proto.PeerKeys{
SshPubKey: pubSSHKey, SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()), WgPubKey: []byte(c.key.PublicKey().String()),
} }
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()}) return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
} }
// GetDeviceAuthorizationFlow returns a device authorization flow information. // GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages. // It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow") return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
} }
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel() defer cancel()
message := &proto.DeviceAuthorizationFlowRequest{} message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -630,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
} }
flowInfoResp := &proto.DeviceAuthorizationFlow{} flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
if err != nil { if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err) errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG) log.Error(errWithMSG)
@@ -642,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
// GetPKCEAuthorizationFlow returns a pkce authorization flow information. // GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages. // It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() { if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow") return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
} }
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2) mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel() defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{} message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message) encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -664,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
} }
flowInfoResp := &proto.PKCEAuthorizationFlow{} flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp) err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
if err != nil { if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err) errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG) log.Error(errWithMSG)
@@ -681,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return errors.New(errMsgNoMgmtConnection) return errors.New(errMsgNoMgmtConnection)
} }
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
log.Debugf(errMsgMgmtPublicKey, err) log.Debugf(errMsgMgmtPublicKey, err)
return err return err
@@ -724,7 +748,7 @@ func (c *GrpcClient) notifyConnected() {
} }
func (c *GrpcClient) Logout() error { func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey() serverKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
return fmt.Errorf("get server public key: %w", err) return fmt.Errorf("get server public key: %w", err)
} }
@@ -751,7 +775,7 @@ func (c *GrpcClient) Logout() error {
// CreateExpose calls the management server to create a new expose service. // CreateExpose calls the management server to create a new expose service.
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) { func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -787,7 +811,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo
// RenewExpose extends the TTL of an active expose session on the management server. // RenewExpose extends the TTL of an active expose session on the management server.
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error { func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
return err return err
} }
@@ -810,7 +834,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
// StopExpose terminates an active expose session on the management server. // StopExpose terminates an active expose session on the management server.
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error { func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.getServerPublicKey()
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,8 +3,6 @@ package client
import ( import (
"context" "context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
@@ -14,12 +12,12 @@ import (
type MockClient struct { type MockClient struct {
CloseFunc func() error CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error) RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetServerURLFunc func() string GetServerURLFunc func() string
HealthCheckFunc func() error
SyncMetaFunc func(sysInfo *system.Info) error SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error LogoutFunc func() error
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
@@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ
return m.JobFunc(ctx, msgHandler) return m.JobFunc(ctx, msgHandler)
} }
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil { if m.RegisterFunc == nil {
return nil, nil return nil, nil
} }
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels) return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels)
} }
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) { func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil { if m.LoginFunc == nil {
return nil, nil return nil, nil
} }
return m.LoginFunc(serverKey, info, sshKey, dnsLabels) return m.LoginFunc(info, sshKey, dnsLabels)
} }
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil { if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil return nil, nil
} }
return m.GetDeviceAuthorizationFlowFunc(serverKey) return m.GetDeviceAuthorizationFlowFunc()
} }
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) { func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil { if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil return nil, nil
} }
return m.GetPKCEAuthorizationFlowFunc(serverKey) return m.GetPKCEAuthorizationFlowFunc()
}
func (m *MockClient) HealthCheck() error {
if m.HealthCheckFunc == nil {
return nil
}
return m.HealthCheckFunc()
} }
// GetNetworkMap mock implementation of GetNetworkMap from Client interface. // GetNetworkMap mock implementation of GetNetworkMap from Client interface.