mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
[client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method Internalize server key fetching into Login, Register, GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods, removing the need for callers to fetch and pass the key separately. Replace the exported GetServerPublicKey with a HealthCheck() error method for connection validation, keeping IsHealthy() bool for non-blocking background monitoring. Fix test encryption to use correct key pairs (client public key as remotePubKey instead of server private key). * Refactor `doMgmLogin` to return only error, removing unused response
This commit is contained in:
@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
|
||||
var needsLogin bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isLoginNeeded(err) {
|
||||
needsLogin = true
|
||||
return nil
|
||||
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
var isAuthError bool
|
||||
|
||||
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
|
||||
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if serverKey != nil && isRegistrationNeeded(err) {
|
||||
err := a.doMgmLogin(client, ctx, pubSSHKey)
|
||||
if isRegistrationNeeded(err) {
|
||||
log.Debugf("peer registration required")
|
||||
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
|
||||
if err != nil {
|
||||
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
|
||||
|
||||
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
|
||||
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
||||
|
||||
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
|
||||
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
|
||||
protoFlow, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
|
||||
log.Warnf("server couldn't find device flow, contact admin: %v", err)
|
||||
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
||||
}
|
||||
|
||||
// doMgmLogin performs the actual login operation with the management service
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(sysInfo)
|
||||
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
|
||||
return err
|
||||
}
|
||||
|
||||
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
|
||||
// Otherwise tries to register with the provided setupKey via command line.
|
||||
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
validSetupKey, err := uuid.Parse(setupKey)
|
||||
if err != nil && jwtToken == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
|
||||
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
|
||||
log.Debugf("sending peer registration request to Management Service")
|
||||
info := system.GetInfo(ctx)
|
||||
a.setSystemInfoFlags(info)
|
||||
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
|
||||
if err != nil {
|
||||
log.Errorf("failed registering peer %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -617,12 +617,6 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
||||
|
||||
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
|
||||
|
||||
serverPublicKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
|
||||
}
|
||||
|
||||
sysInfo := system.GetInfo(ctx)
|
||||
sysInfo.SetFlags(
|
||||
config.RosenpassEnabled,
|
||||
@@ -641,12 +635,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {
|
||||
|
||||
@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, EngineServices{
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicKey, err := mgmtClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
info := system.GetInfo(ctx)
|
||||
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
|
||||
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
|
||||
@@ -777,8 +777,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
|
||||
}()
|
||||
|
||||
// gRPC check
|
||||
_, err = client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
if err = client.HealthCheck(); err != nil {
|
||||
log.Infof("couldn't switch to the new Management %s", newURL.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -16,14 +14,18 @@ type Client interface {
|
||||
io.Closer
|
||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
|
||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||
GetServerURL() string
|
||||
// IsHealthy returns the current connection status without blocking.
|
||||
// Used by the engine to monitor connectivity in the background.
|
||||
IsHealthy() bool
|
||||
// HealthCheck actively probes the management server and returns an error if unreachable.
|
||||
// Used to validate connectivity before committing configuration changes.
|
||||
HealthCheck() error
|
||||
SyncMeta(sysInfo *system.Info) error
|
||||
Logout() error
|
||||
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)
|
||||
|
||||
@@ -189,7 +189,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetServerPublicKey(t *testing.T) {
|
||||
func TestClient_HealthCheck(t *testing.T) {
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -203,12 +203,8 @@ func TestClient_GetServerPublicKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error("couldn't retrieve management public key")
|
||||
}
|
||||
if key == nil {
|
||||
t.Error("got an empty management public key")
|
||||
if err := client.HealthCheck(); err != nil {
|
||||
t.Errorf("health check failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -225,12 +221,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sysInfo := system.GetInfo(context.TODO())
|
||||
_, err = client.Login(*key, sysInfo, nil, nil)
|
||||
_, err = client.Login(sysInfo, nil, nil)
|
||||
if err == nil {
|
||||
t.Error("expecting err on unregistered login, got nil")
|
||||
}
|
||||
@@ -253,12 +245,8 @@ func TestClient_LoginRegistered(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
info := system.GetInfo(context.TODO())
|
||||
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
|
||||
resp, err := client.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -282,13 +270,8 @@ func TestClient_Sync(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
|
||||
_, err = client.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -304,7 +287,7 @@ func TestClient_Sync(t *testing.T) {
|
||||
}
|
||||
|
||||
info = system.GetInfo(context.TODO())
|
||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
|
||||
_, err = remoteClient.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -364,11 +347,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
t.Fatalf("error while creating testClient: %v", err)
|
||||
}
|
||||
|
||||
key, err := testClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Fatalf("error while getting server public key from testclient, %v", err)
|
||||
}
|
||||
|
||||
var actualMeta *mgmtProto.PeerSystemMeta
|
||||
var actualValidKey string
|
||||
var wg sync.WaitGroup
|
||||
@@ -405,7 +383,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
|
||||
_, err = testClient.Register(ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("error while trying to register client: %v", err)
|
||||
}
|
||||
@@ -505,7 +483,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -517,7 +495,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
|
||||
flowInfo, err := client.GetDeviceAuthorizationFlow()
|
||||
if err != nil {
|
||||
t.Error("error while retrieving device auth flow information")
|
||||
}
|
||||
@@ -551,7 +529,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
||||
}
|
||||
|
||||
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -563,7 +541,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
|
||||
flowInfo, err := client.GetPKCEAuthorizationFlow()
|
||||
if err != nil {
|
||||
t.Error("error while retrieving pkce auth flow information")
|
||||
}
|
||||
|
||||
@@ -202,7 +202,7 @@ func (c *GrpcClient) withMgmtStream(
|
||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
@@ -404,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
|
||||
|
||||
// GetNetworkMap return with the network map
|
||||
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf("failed getting Management Service public key: %s", err)
|
||||
return nil, err
|
||||
@@ -490,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
// HealthCheck actively probes the management server and returns an error if unreachable.
|
||||
// Used to validate connectivity before committing configuration changes.
|
||||
func (c *GrpcClient) HealthCheck() error {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
return errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
_, err := c.getServerPublicKey()
|
||||
return err
|
||||
}
|
||||
|
||||
// getServerPublicKey fetches the server's WireGuard public key.
|
||||
func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) {
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, fmt.Errorf("failed while getting Management Service public key")
|
||||
return nil, fmt.Errorf("failed getting Management Service public key: %w", err)
|
||||
}
|
||||
|
||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||
@@ -512,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
return &serverKey, nil
|
||||
}
|
||||
|
||||
// IsHealthy probes the gRPC connection and returns false on errors
|
||||
// IsHealthy returns the current connection status without blocking.
|
||||
// Used by the engine to monitor connectivity in the background.
|
||||
func (c *GrpcClient) IsHealthy() bool {
|
||||
switch c.conn.GetState() {
|
||||
case connectivity.TransientFailure:
|
||||
@@ -538,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return nil, err
|
||||
@@ -577,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
||||
}
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt login response: %s", err)
|
||||
return nil, err
|
||||
@@ -589,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
|
||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||
// Takes care of encrypting and decrypting messages.
|
||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
keys := &proto.PeerKeys{
|
||||
SshPubKey: pubSSHKey,
|
||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||
}
|
||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
keys := &proto.PeerKeys{
|
||||
SshPubKey: pubSSHKey,
|
||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||
}
|
||||
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
||||
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
|
||||
}
|
||||
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.DeviceAuthorizationFlowRequest{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
||||
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -630,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.DeviceAuthorizationFlow{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
||||
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
|
||||
if err != nil {
|
||||
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
||||
log.Error(errWithMSG)
|
||||
@@ -642,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
|
||||
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
||||
func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
||||
}
|
||||
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.PKCEAuthorizationFlowRequest{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
||||
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -664,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.PKCEAuthorizationFlow{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
||||
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
|
||||
if err != nil {
|
||||
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
||||
log.Error(errWithMSG)
|
||||
@@ -681,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
||||
return errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
@@ -724,7 +748,7 @@ func (c *GrpcClient) notifyConnected() {
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Logout() error {
|
||||
serverKey, err := c.GetServerPublicKey()
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get server public key: %w", err)
|
||||
}
|
||||
@@ -751,7 +775,7 @@ func (c *GrpcClient) Logout() error {
|
||||
|
||||
// CreateExpose calls the management server to create a new expose service.
|
||||
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -787,7 +811,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo
|
||||
|
||||
// RenewExpose extends the TTL of an active expose session on the management server.
|
||||
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -810,7 +834,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
|
||||
|
||||
// StopExpose terminates an active expose session on the management server.
|
||||
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
serverPubKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package client
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -14,12 +12,12 @@ import (
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
|
||||
GetServerURLFunc func() string
|
||||
HealthCheckFunc func() error
|
||||
SyncMetaFunc func(sysInfo *system.Info) error
|
||||
LogoutFunc func() error
|
||||
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||
@@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ
|
||||
return m.JobFunc(ctx, msgHandler)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if m.GetServerPublicKeyFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetServerPublicKeyFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
if m.RegisterFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||
return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
if m.LoginFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
|
||||
return m.LoginFunc(info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
||||
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||
if m.GetDeviceAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
||||
return m.GetDeviceAuthorizationFlowFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
||||
func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
|
||||
if m.GetPKCEAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetPKCEAuthorizationFlowFunc(serverKey)
|
||||
return m.GetPKCEAuthorizationFlowFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) HealthCheck() error {
|
||||
if m.HealthCheckFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.HealthCheckFunc()
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from Client interface.
|
||||
|
||||
Reference in New Issue
Block a user