[client] Unexport GetServerPublicKey, add HealthCheck method (#5735)

* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
This commit is contained in:
Zoltan Papp
2026-04-07 12:18:21 +02:00
committed by GitHub
parent 435203b13b
commit 0efef671d7
8 changed files with 106 additions and 145 deletions

View File

@@ -4,8 +4,6 @@ import (
"context"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -16,14 +14,18 @@ type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
GetServerURL() string
// IsHealthy returns the current connection status without blocking.
// Used by the engine to monitor connectivity in the background.
IsHealthy() bool
// HealthCheck actively probes the management server and returns an error if unreachable.
// Used to validate connectivity before committing configuration changes.
HealthCheck() error
SyncMeta(sysInfo *system.Info) error
Logout() error
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)

View File

@@ -189,7 +189,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) {
}
}
func TestClient_GetServerPublicKey(t *testing.T) {
func TestClient_HealthCheck(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
@@ -203,12 +203,8 @@ func TestClient_GetServerPublicKey(t *testing.T) {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("got an empty management public key")
if err := client.HealthCheck(); err != nil {
t.Errorf("health check failed: %v", err)
}
}
@@ -225,12 +221,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil)
_, err = client.Login(sysInfo, nil, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
@@ -253,12 +245,8 @@ func TestClient_LoginRegistered(t *testing.T) {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
resp, err := client.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -282,13 +270,8 @@ func TestClient_Sync(t *testing.T) {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
_, err = client.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -304,7 +287,7 @@ func TestClient_Sync(t *testing.T) {
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
_, err = remoteClient.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -364,11 +347,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
var actualValidKey string
var wg sync.WaitGroup
@@ -405,7 +383,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
_, err = testClient.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
@@ -505,7 +483,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
if err != nil {
return nil, err
}
@@ -517,7 +495,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
flowInfo, err := client.GetDeviceAuthorizationFlow()
if err != nil {
t.Error("error while retrieving device auth flow information")
}
@@ -551,7 +529,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
if err != nil {
return nil, err
}
@@ -563,7 +541,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
flowInfo, err := client.GetPKCEAuthorizationFlow()
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}

View File

@@ -202,7 +202,7 @@ func (c *GrpcClient) withMgmtStream(
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -404,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
@@ -490,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
// HealthCheck actively probes the management server and returns an error if unreachable.
// Used to validate connectivity before committing configuration changes.
func (c *GrpcClient) HealthCheck() error {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
return errors.New(errMsgNoMgmtConnection)
}
_, err := c.getServerPublicKey()
return err
}
// getServerPublicKey fetches the server's WireGuard public key.
func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) {
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
return nil, fmt.Errorf("failed getting Management Service public key: %w", err)
}
serverKey, err := wgtypes.ParseKey(resp.Key)
@@ -512,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
return &serverKey, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
// IsHealthy returns the current connection status without blocking.
// Used by the engine to monitor connectivity in the background.
func (c *GrpcClient) IsHealthy() bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
@@ -538,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool {
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
@@ -577,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
@@ -589,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return nil, err
}
@@ -630,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
@@ -642,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return nil, err
}
@@ -664,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
@@ -681,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -724,7 +748,7 @@ func (c *GrpcClient) notifyConnected() {
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
serverKey, err := c.getServerPublicKey()
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
@@ -751,7 +775,7 @@ func (c *GrpcClient) Logout() error {
// CreateExpose calls the management server to create a new expose service.
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
@@ -787,7 +811,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo
// RenewExpose extends the TTL of an active expose session on the management server.
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
return err
}
@@ -810,7 +834,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
// StopExpose terminates an active expose session on the management server.
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
return err
}

View File

@@ -3,8 +3,6 @@ package client
import (
"context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -14,12 +12,12 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
GetServerURLFunc func() string
HealthCheckFunc func() error
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
@@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ
return m.JobFunc(ctx, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
return m.LoginFunc(info, sshKey, dnsLabels)
}
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
return m.GetDeviceAuthorizationFlowFunc()
}
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetPKCEAuthorizationFlowFunc(serverKey)
return m.GetPKCEAuthorizationFlowFunc()
}
func (m *MockClient) HealthCheck() error {
if m.HealthCheckFunc == nil {
return nil
}
return m.HealthCheckFunc()
}
// GetNetworkMap mock implementation of GetNetworkMap from Client interface.