Files
netbird/shared/management/client/client_test.go
Zoltan Papp 0efef671d7 [client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
2026-04-07 12:18:21 +02:00

552 lines
16 KiB
Go

package client
import (
"context"
"net"
"os"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/management-integrations/integrations"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
// ValidKey is the setup key the tests in this file pass to Register.
// NOTE(review): presumably this key exists in the store.sql test fixture — confirm.
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
// TestMain initialises console logging at debug level, runs every test in the
// package and exits the process with the status code reported by the run.
func TestMain(m *testing.M) {
	// The error returned by InitLog is discarded.
	_ = util.InitLog("debug", util.LogConsole)

	os.Exit(m.Run())
}
// startManagement wires up a management gRPC service backed by a test store that is
// loaded from the SQL fixture, registers it on a new grpc.Server and starts serving
// on an OS-assigned TCP port in a background goroutine.
//
// It returns the gRPC server and the listener it serves on. Any error during setup
// fails the calling test immediately.
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper()
// Force the logrus level to debug; the parse error is discarded.
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
// Load the management configuration from the shared testdata file.
// NOTE: from here on the local `config` shadows the imported config package.
config := &config.Config{}
_, err := util.ReadJson("../../../management/server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
// ":0" lets the OS pick a free port; callers read it back from the listener.
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
// Build the test store from the SQL fixture inside a per-test temp directory.
// NOTE: the local `store` shadows the imported store package from here on.
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../../management/server/testdata/store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
// Permissions mock: every ValidateUserPermissions call is allowed, any number of times.
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.
EXPECT().
ValidateUserPermissions(
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(true, nil).
AnyTimes()
peersManger := peers.NewManager(store, permissionsManagerMock)
// settingsManagerMock has no expectations set and is only handed to the
// integrated validator below; settingsMockManager (declared further down)
// is the one used by the rest of the wiring.
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManger)
// The error from NewIntegratedValidator is discarded.
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
// Settings mock: GetSettings and GetExtraSettings return empty values, any number of times.
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
GetSettings(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
ctx := context.Background()
// Assemble the network map controller and the account manager on top of the store.
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
groupsManager := groups.NewManagerMock()
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
t.Fatal(err)
}
// Create the management gRPC service and register it on the server.
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
// Serve in the background; a non-nil error from Serve is reported on the test.
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis
}
// startMockManagement starts a gRPC server on an OS-assigned TCP port that serves a
// ManagementServiceServerMock. Only GetServerKeyFunc is preconfigured: it answers with
// the public part of a freshly generated server key. Tests install further handlers
// on the returned mock.
//
// It returns the gRPC server, its listener, the mock and the generated server key.
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
	t.Helper()

	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatal(err)
	}

	grpcServer := grpc.NewServer()

	srvKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	mock := &mock_server.ManagementServiceServerMock{}
	mock.GetServerKeyFunc = func(context.Context, *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) {
		return &mgmtProto.ServerKeyResponse{Key: srvKey.PublicKey().String()}, nil
	}
	mgmtProto.RegisterManagementServiceServer(grpcServer, mock)

	// Serve in the background; a non-nil error from Serve is reported on the test.
	go func() {
		serveErr := grpcServer.Serve(listener)
		if serveErr == nil {
			return
		}
		t.Error(serveErr)
	}()

	return grpcServer, listener, mock, srvKey
}
// closeManagementSilently gracefully stops the gRPC server and then closes the
// listener. A failure to close the listener is logged as a warning, not returned.
func closeManagementSilently(s *grpc.Server, listener net.Listener) {
	s.GracefulStop()

	if closeErr := listener.Close(); closeErr != nil {
		log.Warnf("error while closing management listener %v", closeErr)
	}
}
// TestClient_HealthCheck starts a management server, connects a client with a
// freshly generated key and expects HealthCheck to succeed.
func TestClient_HealthCheck(t *testing.T) {
	privKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	srv, lis := startManagement(t)
	defer closeManagementSilently(srv, lis)

	mgmClient, err := NewClient(context.Background(), lis.Addr().String(), privKey, false)
	if err != nil {
		t.Fatal(err)
	}

	err = mgmClient.HealthCheck()
	if err != nil {
		t.Errorf("health check failed: %v", err)
	}
}
// TestClient_LoginUnregistered_ShouldThrow_401 calls Login with a key that was never
// registered in this test and expects a gRPC error carrying codes.PermissionDenied.
func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
	testKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	ctx := context.Background()
	s, listener := startManagement(t)
	defer closeManagementSilently(s, listener)

	client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
	if err != nil {
		t.Fatal(err)
	}

	sysInfo := system.GetInfo(context.TODO())
	_, err = client.Login(sysInfo, nil, nil)
	if err == nil {
		// Stop here: with a nil error there is no status to inspect, and the
		// code check below would only add a second, misleading failure.
		t.Fatal("expecting err on unregistered login, got nil")
	}

	// Named st so it does not shadow the gRPC server variable s.
	st, ok := status.FromError(err)
	if !ok || st.Code() != codes.PermissionDenied {
		t.Errorf("expecting err code %d denied on unregistered login got %d", codes.PermissionDenied, st.Code())
	}
}
// TestClient_LoginRegistered registers a new peer with ValidKey against a real
// management server and expects a non-nil response without error.
func TestClient_LoginRegistered(t *testing.T) {
	peerKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	srv, lis := startManagement(t)
	defer closeManagementSilently(srv, lis)

	mgmClient, err := NewClient(context.Background(), lis.Addr().String(), peerKey, false)
	if err != nil {
		t.Fatal(err)
	}

	sysInfo := system.GetInfo(context.TODO())
	loginResp, regErr := mgmClient.Register(ValidKey, "", sysInfo, nil, nil)
	if regErr != nil {
		t.Error(regErr)
	}
	if loginResp == nil {
		t.Error("expecting non nil response, got nil")
	}
}
// TestClient_Sync registers two peers with the management server, opens a Sync
// stream for the first one and checks that the first SyncResponse carries a peer
// config, a NetBird config and exactly one remote peer whose WireGuard public key
// is the second peer's key.
func TestClient_Sync(t *testing.T) {
	testKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	ctx := context.Background()
	s, listener := startManagement(t)
	defer closeManagementSilently(s, listener)

	client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
	if err != nil {
		t.Fatal(err)
	}

	info := system.GetInfo(context.TODO())
	// Without a registered peer the rest of the test can only time out, so abort early.
	if _, err = client.Register(ValidKey, "", info, nil, nil); err != nil {
		t.Fatal(err)
	}

	// create and register second peer (we should receive on Sync request)
	remoteKey, err := wgtypes.GenerateKey()
	if err != nil {
		// A zero key is useless for the assertions below.
		t.Fatal(err)
	}

	remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
	if err != nil {
		t.Fatal(err)
	}

	info = system.GetInfo(context.TODO())
	if _, err = remoteClient.Register(ValidKey, "", info, nil, nil); err != nil {
		t.Fatal(err)
	}

	ch := make(chan *mgmtProto.SyncResponse, 1)
	syncCtx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go func() {
		// The result stays in a goroutine-local variable so the test's own err
		// is never written from another goroutine.
		syncErr := client.Sync(syncCtx, info, func(msg *mgmtProto.SyncResponse) error {
			// Only the first message is consumed by the test. Give up on later
			// sends once the context is cancelled instead of blocking forever
			// on a channel nobody reads any more.
			select {
			case ch <- msg:
				return nil
			case <-syncCtx.Done():
				return syncCtx.Err()
			}
		})
		if syncErr != nil {
			log.Debugf("sync stream ended: %v", syncErr)
		}
	}()

	select {
	case resp := <-ch:
		if resp.GetPeerConfig() == nil {
			t.Error("expecting non nil PeerConfig got nil")
		}
		if resp.GetNetbirdConfig() == nil {
			t.Error("expecting non nil NetbirdConfig got nil")
		}

		remotePeers := resp.GetRemotePeers()
		if len(remotePeers) != 1 {
			t.Errorf("expecting RemotePeers size %d got %d", 1, len(remotePeers))
			// Return before indexing remotePeers[0] below.
			return
		}
		if resp.GetRemotePeersIsEmpty() {
			t.Error("expecting RemotePeers property to be false, got true")
		}
		if want, got := remoteKey.PublicKey().String(), remotePeers[0].GetWgPubKey(); got != want {
			t.Errorf("expecting RemotePeer public key %s got %s", want, got)
		}
	case <-time.After(3 * time.Second):
		t.Error("timeout waiting for test to finish")
	}
}
// Test_SystemMetaDataFromClient registers a client against the mock management
// server, captures the decrypted LoginRequest inside the mock's Login handler and
// checks that the setup key and the system metadata sent by the client match what
// system.GetInfo reported locally.
func Test_SystemMetaDataFromClient(t *testing.T) {
	s, lis, mgmtMockServer, serverKey := startMockManagement(t)
	defer s.GracefulStop()

	testKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	serverAddr := lis.Addr().String()
	ctx := context.Background()

	testClient, err := NewClient(ctx, serverAddr, testKey, false)
	if err != nil {
		t.Fatalf("error while creating testClient: %v", err)
	}

	var (
		actualMeta     *mgmtProto.PeerSystemMeta
		actualValidKey string
		capture        sync.Once
		wg             sync.WaitGroup
	)
	wg.Add(1)

	mgmtMockServer.LoginFunc = func(ctx context.Context, msg *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
		peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
		if err != nil {
			log.Warnf("error while parsing peer's Wireguard public key %s on Login request.", msg.WgPubKey)
			return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
		}

		loginReq := &mgmtProto.LoginRequest{}
		if err := encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq); err != nil {
			// Report the failure to the caller instead of terminating the whole
			// test binary from inside a handler goroutine.
			return nil, status.Errorf(codes.InvalidArgument, "unable to decrypt login request: %v", err)
		}

		// Capture only the first request and signal exactly once, so a repeated
		// call cannot drive the WaitGroup counter negative.
		capture.Do(func() {
			actualMeta = loginReq.GetMeta()
			actualValidKey = loginReq.GetSetupKey()
			wg.Done()
		})

		loginResp := &mgmtProto.LoginResponse{}
		encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
		if err != nil {
			return nil, err
		}

		return &mgmtProto.EncryptedMessage{
			WgPubKey: serverKey.PublicKey().String(),
			Body:     encryptedResp,
			Version:  0,
		}, nil
	}

	info := system.GetInfo(context.TODO())
	if _, err = testClient.Register(ValidKey, "", info, nil, nil); err != nil {
		// Abort: if the handler never captured a request, wg.Wait below would block forever.
		t.Fatalf("error while trying to register client: %v", err)
	}

	// Wait for the handler's writes to actualMeta/actualValidKey to become visible.
	wg.Wait()

	protoNetAddr := make([]*mgmtProto.NetworkAddress, 0, len(info.NetworkAddresses))
	for _, addr := range info.NetworkAddresses {
		protoNetAddr = append(protoNetAddr, &mgmtProto.NetworkAddress{
			NetIP: addr.NetIP.String(),
			Mac:   addr.Mac,
		})
	}

	expectedMeta := &mgmtProto.PeerSystemMeta{
		Hostname: info.Hostname,
		GoOS:     info.GoOS,
		Kernel:   info.Kernel,
		Platform: info.Platform,
		OS:       info.OS,
		// NOTE(review): Core is filled from OSVersion — presumably mirroring the
		// client's own conversion; confirm against the client code.
		Core:             info.OSVersion,
		OSVersion:        info.OSVersion,
		NetbirdVersion:   info.NetbirdVersion,
		KernelVersion:    info.KernelVersion,
		NetworkAddresses: protoNetAddr,
		SysSerialNumber:  info.SystemSerialNumber,
		SysProductName:   info.SystemProductName,
		SysManufacturer:  info.SystemManufacturer,
		Environment:      &mgmtProto.Environment{Cloud: info.Environment.Cloud, Platform: info.Environment.Platform},
	}

	assert.Equal(t, ValidKey, actualValidKey)

	if !isEqual(expectedMeta, actualMeta) {
		t.Errorf("expected and actual meta are not equal")
	}
}
// isEqual reports whether two PeerSystemMeta messages describe the same system.
//
// The network address lists must have the same length and every address in a must
// have an entry in b with the same MAC and IP; order is ignored. All scalar fields
// compared below, plus the environment's cloud and platform, must be equal.
//
// Only generated getters are used, so a nil message or a nil Environment compares
// as its zero value instead of panicking.
func isEqual(a, b *mgmtProto.PeerSystemMeta) bool {
	aAddrs, bAddrs := a.GetNetworkAddresses(), b.GetNetworkAddresses()
	if len(aAddrs) != len(bAddrs) {
		return false
	}

	for _, addr := range aAddrs {
		found := false
		for _, other := range bAddrs {
			if addr.GetMac() == other.GetMac() && addr.GetNetIP() == other.GetNetIP() {
				found = true
				// One match is enough; no need to scan the rest of b.
				break
			}
		}
		if !found {
			return false
		}
	}

	return a.GetHostname() == b.GetHostname() &&
		a.GetGoOS() == b.GetGoOS() &&
		a.GetKernel() == b.GetKernel() &&
		a.GetKernelVersion() == b.GetKernelVersion() &&
		a.GetCore() == b.GetCore() &&
		a.GetPlatform() == b.GetPlatform() &&
		a.GetOS() == b.GetOS() &&
		a.GetOSVersion() == b.GetOSVersion() &&
		a.GetNetbirdVersion() == b.GetNetbirdVersion() &&
		a.GetUiVersion() == b.GetUiVersion() &&
		a.GetSysSerialNumber() == b.GetSysSerialNumber() &&
		a.GetSysProductName() == b.GetSysProductName() &&
		a.GetSysManufacturer() == b.GetSysManufacturer() &&
		a.GetEnvironment().GetCloud() == b.GetEnvironment().GetCloud() &&
		a.GetEnvironment().GetPlatform() == b.GetEnvironment().GetPlatform()
}
// Test_GetDeviceAuthorizationFlow configures the mock management server to answer
// with an encrypted DeviceAuthorizationFlow and checks that the client returns the
// same provider and provider client ID.
func Test_GetDeviceAuthorizationFlow(t *testing.T) {
	s, lis, mgmtMockServer, serverKey := startMockManagement(t)
	defer s.GracefulStop()

	testKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	serverAddr := lis.Addr().String()
	ctx := context.Background()

	client, err := NewClient(ctx, serverAddr, testKey, false)
	if err != nil {
		t.Fatalf("error while creating testClient: %v", err)
	}

	expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
		Provider:       0,
		ProviderConfig: &mgmtProto.ProviderConfig{ClientID: "client"},
	}

	mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
		// Encrypt for the client's public key with the server key.
		encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
		if err != nil {
			return nil, err
		}

		return &mgmtProto.EncryptedMessage{
			WgPubKey: serverKey.PublicKey().String(),
			Body:     encryptedResp,
			Version:  0,
		}, nil
	}

	flowInfo, err := client.GetDeviceAuthorizationFlow()
	if err != nil {
		// Abort: flowInfo must not be dereferenced after a failed call.
		t.Fatalf("error while retrieving device auth flow information: %v", err)
	}

	assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
	assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
}
// Test_GetPKCEAuthorizationFlow configures the mock management server to answer
// with an encrypted PKCEAuthorizationFlow and checks that the client returns the
// same provider client ID and client secret.
func Test_GetPKCEAuthorizationFlow(t *testing.T) {
	s, lis, mgmtMockServer, serverKey := startMockManagement(t)
	defer s.GracefulStop()

	testKey, err := wgtypes.GenerateKey()
	if err != nil {
		t.Fatal(err)
	}

	serverAddr := lis.Addr().String()
	ctx := context.Background()

	client, err := NewClient(ctx, serverAddr, testKey, false)
	if err != nil {
		t.Fatalf("error while creating testClient: %v", err)
	}

	expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
		ProviderConfig: &mgmtProto.ProviderConfig{
			ClientID:     "client",
			ClientSecret: "secret",
		},
	}

	mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
		// Encrypt for the client's public key with the server key.
		encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
		if err != nil {
			return nil, err
		}

		return &mgmtProto.EncryptedMessage{
			WgPubKey: serverKey.PublicKey().String(),
			Body:     encryptedResp,
			Version:  0,
		}, nil
	}

	flowInfo, err := client.GetPKCEAuthorizationFlow()
	if err != nil {
		// Abort: flowInfo must not be dereferenced after a failed call.
		t.Fatalf("error while retrieving pkce auth flow information: %v", err)
	}

	assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
	assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") //nolint:staticcheck
}