mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method Internalize server key fetching into Login, Register, GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods, removing the need for callers to fetch and pass the key separately. Replace the exported GetServerPublicKey with a HealthCheck() error method for connection validation, keeping IsHealthy() bool for non-blocking background monitoring. Fix test encryption to use correct key pairs (client public key as remotePubKey instead of server private key). * Refactor `doMgmLogin` to return only error, removing unused response
This commit is contained in:
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
|||||||
var needsLogin bool
|
var needsLogin bool
|
||||||
|
|
||||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
if isLoginNeeded(err) {
|
if isLoginNeeded(err) {
|
||||||
needsLogin = true
|
needsLogin = true
|
||||||
return nil
|
return nil
|
||||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
|||||||
var isAuthError bool
|
var isAuthError bool
|
||||||
|
|
||||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||||
if serverKey != nil && isRegistrationNeeded(err) {
|
if isRegistrationNeeded(err) {
|
||||||
log.Debugf("peer registration required")
|
log.Debugf("peer registration required")
|
||||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
|||||||
|
|
||||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||||
serverKey, err := client.GetServerPublicKey()
|
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
|||||||
|
|
||||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||||
serverKey, err := client.GetServerPublicKey()
|
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// doMgmLogin performs the actual login operation with the management service
|
// doMgmLogin performs the actual login operation with the management service
|
||||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||||
serverKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
a.setSystemInfoFlags(sysInfo)
|
a.setSystemInfoFlags(sysInfo)
|
||||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||||
return serverKey, loginResp, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||||
// Otherwise tries to register with the provided setupKey via command line.
|
// Otherwise tries to register with the provided setupKey via command line.
|
||||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
validSetupKey, err := uuid.Parse(setupKey)
|
validSetupKey, err := uuid.Parse(setupKey)
|
||||||
if err != nil && jwtToken == "" {
|
if err != nil && jwtToken == "" {
|
||||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
|||||||
log.Debugf("sending peer registration request to Management Service")
|
log.Debugf("sending peer registration request to Management Service")
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
a.setSystemInfoFlags(info)
|
a.setSystemInfoFlags(info)
|
||||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed registering peer %v", err)
|
log.Errorf("failed registering peer %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -617,12 +617,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
|
|
||||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sysInfo := system.GetInfo(ctx)
|
sysInfo := system.GetInfo(ctx)
|
||||||
sysInfo.SetFlags(
|
sysInfo.SetFlags(
|
||||||
config.RosenpassEnabled,
|
config.RosenpassEnabled,
|
||||||
@@ -641,12 +635,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
config.DisableSSHAuth,
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return loginResp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||||
|
|||||||
@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, EngineServices{
|
}, EngineServices{
|
||||||
SignalClient: &signal.MockClient{},
|
SignalClient: &signal.MockClient{},
|
||||||
MgmClient: &mgmt.MockClient{},
|
MgmClient: &mgmt.MockClient{},
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, EngineServices{
|
}, EngineServices{
|
||||||
SignalClient: &signal.MockClient{},
|
SignalClient: &signal.MockClient{},
|
||||||
MgmClient: &mgmt.MockClient{},
|
MgmClient: &mgmt.MockClient{},
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(ctx)
|
info := system.GetInfo(ctx)
|
||||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||||
SignalClient: signalClient,
|
SignalClient: signalClient,
|
||||||
MgmClient: mgmtClient,
|
MgmClient: mgmtClient,
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
|
|||||||
@@ -777,8 +777,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// gRPC check
|
// gRPC check
|
||||||
_, err = client.GetServerPublicKey()
|
if err = client.HealthCheck(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -16,14 +14,18 @@ type Client interface {
|
|||||||
io.Closer
|
io.Closer
|
||||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||||
GetServerPublicKey() (*wgtypes.Key, error)
|
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
|
||||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
|
||||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
|
||||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||||
GetServerURL() string
|
GetServerURL() string
|
||||||
|
// IsHealthy returns the current connection status without blocking.
|
||||||
|
// Used by the engine to monitor connectivity in the background.
|
||||||
IsHealthy() bool
|
IsHealthy() bool
|
||||||
|
// HealthCheck actively probes the management server and returns an error if unreachable.
|
||||||
|
// Used to validate connectivity before committing configuration changes.
|
||||||
|
HealthCheck() error
|
||||||
SyncMeta(sysInfo *system.Info) error
|
SyncMeta(sysInfo *system.Info) error
|
||||||
Logout() error
|
Logout() error
|
||||||
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)
|
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClient_GetServerPublicKey(t *testing.T) {
|
func TestClient_HealthCheck(t *testing.T) {
|
||||||
testKey, err := wgtypes.GenerateKey()
|
testKey, err := wgtypes.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -203,12 +203,8 @@ func TestClient_GetServerPublicKey(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := client.GetServerPublicKey()
|
if err := client.HealthCheck(); err != nil {
|
||||||
if err != nil {
|
t.Errorf("health check failed: %v", err)
|
||||||
t.Error("couldn't retrieve management public key")
|
|
||||||
}
|
|
||||||
if key == nil {
|
|
||||||
t.Error("got an empty management public key")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,12 +221,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
key, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
sysInfo := system.GetInfo(context.TODO())
|
sysInfo := system.GetInfo(context.TODO())
|
||||||
_, err = client.Login(*key, sysInfo, nil, nil)
|
_, err = client.Login(sysInfo, nil, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expecting err on unregistered login, got nil")
|
t.Error("expecting err on unregistered login, got nil")
|
||||||
}
|
}
|
||||||
@@ -253,12 +245,8 @@ func TestClient_LoginRegistered(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO())
|
||||||
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
|
resp, err := client.Register(ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -282,13 +270,8 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := client.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO())
|
||||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
|
_, err = client.Register(ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -304,7 +287,7 @@ func TestClient_Sync(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
info = system.GetInfo(context.TODO())
|
info = system.GetInfo(context.TODO())
|
||||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
|
_, err = remoteClient.Register(ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -364,11 +347,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
|||||||
t.Fatalf("error while creating testClient: %v", err)
|
t.Fatalf("error while creating testClient: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, err := testClient.GetServerPublicKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error while getting server public key from testclient, %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var actualMeta *mgmtProto.PeerSystemMeta
|
var actualMeta *mgmtProto.PeerSystemMeta
|
||||||
var actualValidKey string
|
var actualValidKey string
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@@ -405,7 +383,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
info := system.GetInfo(context.TODO())
|
info := system.GetInfo(context.TODO())
|
||||||
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
|
_, err = testClient.Register(ValidKey, "", info, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error while trying to register client: %v", err)
|
t.Errorf("error while trying to register client: %v", err)
|
||||||
}
|
}
|
||||||
@@ -505,7 +483,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -517,7 +495,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
|
flowInfo, err := client.GetDeviceAuthorizationFlow()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("error while retrieving device auth flow information")
|
t.Error("error while retrieving device auth flow information")
|
||||||
}
|
}
|
||||||
@@ -551,7 +529,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -563,7 +541,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
|
flowInfo, err := client.GetPKCEAuthorizationFlow()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error("error while retrieving pkce auth flow information")
|
t.Error("error while retrieving pkce auth flow information")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ func (c *GrpcClient) withMgmtStream(
|
|||||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf(errMsgMgmtPublicKey, err)
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
return err
|
return err
|
||||||
@@ -404,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
|||||||
|
|
||||||
// GetNetworkMap return with the network map
|
// GetNetworkMap return with the network map
|
||||||
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed getting Management Service public key: %s", err)
|
log.Debugf("failed getting Management Service public key: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -490,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
// HealthCheck actively probes the management server and returns an error if unreachable.
|
||||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
// Used to validate connectivity before committing configuration changes.
|
||||||
|
func (c *GrpcClient) HealthCheck() error {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return nil, errors.New(errMsgNoMgmtConnection)
|
return errors.New(errMsgNoMgmtConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err := c.getServerPublicKey()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getServerPublicKey fetches the server's WireGuard public key.
|
||||||
|
func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) {
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
return nil, fmt.Errorf("failed getting Management Service public key: %w", err)
|
||||||
return nil, fmt.Errorf("failed while getting Management Service public key")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||||
@@ -512,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
|||||||
return &serverKey, nil
|
return &serverKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsHealthy probes the gRPC connection and returns false on errors
|
// IsHealthy returns the current connection status without blocking.
|
||||||
|
// Used by the engine to monitor connectivity in the background.
|
||||||
func (c *GrpcClient) IsHealthy() bool {
|
func (c *GrpcClient) IsHealthy() bool {
|
||||||
switch c.conn.GetState() {
|
switch c.conn.GetState() {
|
||||||
case connectivity.TransientFailure:
|
case connectivity.TransientFailure:
|
||||||
@@ -538,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return nil, errors.New(errMsgNoMgmtConnection)
|
return nil, errors.New(errMsgNoMgmtConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
serverKey, err := c.getServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to encrypt message: %s", err)
|
log.Errorf("failed to encrypt message: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -577,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginResp := &proto.LoginResponse{}
|
loginResp := &proto.LoginResponse{}
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to decrypt login response: %s", err)
|
log.Errorf("failed to decrypt login response: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -589,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
|||||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||||
// Takes care of encrypting and decrypting messages.
|
// Takes care of encrypting and decrypting messages.
|
||||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||||
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||||
keys := &proto.PeerKeys{
|
keys := &proto.PeerKeys{
|
||||||
SshPubKey: pubSSHKey,
|
SshPubKey: pubSSHKey,
|
||||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||||
}
|
}
|
||||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||||
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||||
keys := &proto.PeerKeys{
|
keys := &proto.PeerKeys{
|
||||||
SshPubKey: pubSSHKey,
|
SshPubKey: pubSSHKey,
|
||||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||||
}
|
}
|
||||||
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
||||||
// It also takes care of encrypting and decrypting messages.
|
// It also takes care of encrypting and decrypting messages.
|
||||||
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
|
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serverKey, err := c.getServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
message := &proto.DeviceAuthorizationFlowRequest{}
|
message := &proto.DeviceAuthorizationFlowRequest{}
|
||||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -630,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
|||||||
}
|
}
|
||||||
|
|
||||||
flowInfoResp := &proto.DeviceAuthorizationFlow{}
|
flowInfoResp := &proto.DeviceAuthorizationFlow{}
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
||||||
log.Error(errWithMSG)
|
log.Error(errWithMSG)
|
||||||
@@ -642,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
|||||||
|
|
||||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
||||||
// It also takes care of encrypting and decrypting messages.
|
// It also takes care of encrypting and decrypting messages.
|
||||||
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
|
||||||
if !c.ready() {
|
if !c.ready() {
|
||||||
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serverKey, err := c.getServerPublicKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
message := &proto.PKCEAuthorizationFlowRequest{}
|
message := &proto.PKCEAuthorizationFlowRequest{}
|
||||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -664,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
|
|||||||
}
|
}
|
||||||
|
|
||||||
flowInfoResp := &proto.PKCEAuthorizationFlow{}
|
flowInfoResp := &proto.PKCEAuthorizationFlow{}
|
||||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
||||||
log.Error(errWithMSG)
|
log.Error(errWithMSG)
|
||||||
@@ -681,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
|||||||
return errors.New(errMsgNoMgmtConnection)
|
return errors.New(errMsgNoMgmtConnection)
|
||||||
}
|
}
|
||||||
|
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf(errMsgMgmtPublicKey, err)
|
log.Debugf(errMsgMgmtPublicKey, err)
|
||||||
return err
|
return err
|
||||||
@@ -724,7 +748,7 @@ func (c *GrpcClient) notifyConnected() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *GrpcClient) Logout() error {
|
func (c *GrpcClient) Logout() error {
|
||||||
serverKey, err := c.GetServerPublicKey()
|
serverKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get server public key: %w", err)
|
return fmt.Errorf("get server public key: %w", err)
|
||||||
}
|
}
|
||||||
@@ -751,7 +775,7 @@ func (c *GrpcClient) Logout() error {
|
|||||||
|
|
||||||
// CreateExpose calls the management server to create a new expose service.
|
// CreateExpose calls the management server to create a new expose service.
|
||||||
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
|
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -787,7 +811,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo
|
|||||||
|
|
||||||
// RenewExpose extends the TTL of an active expose session on the management server.
|
// RenewExpose extends the TTL of an active expose session on the management server.
|
||||||
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -810,7 +834,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
|||||||
|
|
||||||
// StopExpose terminates an active expose session on the management server.
|
// StopExpose terminates an active expose session on the management server.
|
||||||
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
|
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
|
||||||
serverPubKey, err := c.GetServerPublicKey()
|
serverPubKey, err := c.getServerPublicKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package client
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -14,12 +12,12 @@ import (
|
|||||||
type MockClient struct {
|
type MockClient struct {
|
||||||
CloseFunc func() error
|
CloseFunc func() error
|
||||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
|
||||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
|
||||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
|
||||||
GetServerURLFunc func() string
|
GetServerURLFunc func() string
|
||||||
|
HealthCheckFunc func() error
|
||||||
SyncMetaFunc func(sysInfo *system.Info) error
|
SyncMetaFunc func(sysInfo *system.Info) error
|
||||||
LogoutFunc func() error
|
LogoutFunc func() error
|
||||||
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||||
@@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ
|
|||||||
return m.JobFunc(ctx, msgHandler)
|
return m.JobFunc(ctx, msgHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||||
if m.GetServerPublicKeyFunc == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return m.GetServerPublicKeyFunc()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
|
||||||
if m.RegisterFunc == nil {
|
if m.RegisterFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
|
return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||||
if m.LoginFunc == nil {
|
if m.LoginFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
|
return m.LoginFunc(info, sshKey, dnsLabels)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||||
if m.GetDeviceAuthorizationFlowFunc == nil {
|
if m.GetDeviceAuthorizationFlowFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
return m.GetDeviceAuthorizationFlowFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
|
||||||
if m.GetPKCEAuthorizationFlowFunc == nil {
|
if m.GetPKCEAuthorizationFlowFunc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return m.GetPKCEAuthorizationFlowFunc(serverKey)
|
return m.GetPKCEAuthorizationFlowFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockClient) HealthCheck() error {
|
||||||
|
if m.HealthCheckFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.HealthCheckFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkMap mock implementation of GetNetworkMap from Client interface.
|
// GetNetworkMap mock implementation of GetNetworkMap from Client interface.
|
||||||
|
|||||||
Reference in New Issue
Block a user