diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 5660d19e3..0e83f80f4 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -45,7 +45,8 @@ type GrpcClient struct { connStateCallback ConnStateNotifier connStateCallbackLock sync.RWMutex - srvKey *wgtypes.Key + srvKey *wgtypes.Key + srvKeyMu sync.RWMutex } // NewClient creates a new client to Management service @@ -124,11 +125,14 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler return fmt.Errorf("connection to management is not ready and in %s state", connState) } - serverPubKey, err := c.GetServerPublicKey() + serverPubKey, err := c.refreshServerKey() if err != nil { log.Debugf(errMsgMgmtPublicKey, err) return err } + c.srvKeyMu.Lock() + c.srvKey = serverPubKey + c.srvKeyMu.Unlock() return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) } @@ -272,25 +276,21 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { return nil, errors.New(errMsgNoMgmtConnection) } + c.srvKeyMu.RLock() if c.srvKey != nil { + c.srvKeyMu.RUnlock() return c.srvKey, nil } + c.srvKeyMu.RUnlock() - mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) - defer cancel() - resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) - if err != nil { - log.Errorf("failed while getting Management Service public key: %v", err) - return nil, fmt.Errorf("failed while getting Management Service public key") - } - - serverKey, err := wgtypes.ParseKey(resp.Key) + srvKey, err := c.refreshServerKey() if err != nil { return nil, err } - c.srvKey = &serverKey - - return &serverKey, nil + c.srvKeyMu.Lock() + c.srvKey = srvKey + c.srvKeyMu.Unlock() + return srvKey, nil } // IsHealthy probes the gRPC connection and returns false on errors @@ -319,6 +319,26 @@ func (c *GrpcClient) IsHealthy() bool { return true } +func (c *GrpcClient) refreshServerKey() (*wgtypes.Key, error) { + if !c.ready() { + return nil, errors.New(errMsgNoMgmtConnection) + } + + mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) + defer cancel() + resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) + if err != nil { + log.Errorf("failed while getting Management Service public key: %v", err) + return nil, fmt.Errorf("failed while getting Management Service public key") + } + + serverKey, err := wgtypes.ParseKey(resp.Key) + if err != nil { + return nil, err + } + return &serverKey, nil +} + func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { if !c.ready() { return nil, errors.New(errMsgNoMgmtConnection)