diff --git a/client/cmd/up.go b/client/cmd/up.go index c4a94197c..afe69b68a 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -94,7 +94,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { var cancel context.CancelFunc ctx, cancel = context.WithCancel(ctx) SetupCloseHandler(ctx, cancel) - return internal.RunClient(ctx, config, peer.NewRecorder()) + return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String())) } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { diff --git a/client/internal/connect.go b/client/internal/connect.go index 511201f84..4a3d052b7 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -58,8 +58,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) return err } - managementURL := config.ManagementURL.String() - statusRecorder.MarkManagementDisconnected(managementURL) + statusRecorder.MarkManagementDisconnected() operation := func() error { // if context cancelled we not start new backoff cycle @@ -73,13 +72,16 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) engineCtx, cancel := context.WithCancel(ctx) defer func() { - statusRecorder.MarkManagementDisconnected(managementURL) + statusRecorder.MarkManagementDisconnected() statusRecorder.CleanLocalPeerState() cancel() }() log.Debugf("conecting to the Management service %s", config.ManagementURL.Host) mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) + mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder) + mgmClient.SetConnStateListener(mgmNotifier) + if err != nil { return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) } @@ -101,7 +103,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) } return wrapErr(err) } - statusRecorder.MarkManagementConnected(managementURL) + statusRecorder.MarkManagementConnected() localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), @@ -117,8 +119,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), ) - statusRecorder.MarkSignalDisconnected(signalURL) - defer statusRecorder.MarkSignalDisconnected(signalURL) + statusRecorder.UpdateSignalAddress(signalURL) + + statusRecorder.MarkSignalDisconnected() + defer statusRecorder.MarkSignalDisconnected() // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal signalClient, err := connectToSignal(engineCtx, loginResp.GetWiretrusteeConfig(), myPrivateKey) @@ -133,7 +137,10 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) } }() - statusRecorder.MarkSignalConnected(signalURL) + signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder) + signalClient.SetConnStateListener(signalNotifier) + + statusRecorder.MarkSignalConnected() peerConfig := loginResp.GetPeerConfig() @@ -320,3 +327,15 @@ func UpdateOldManagementPort(ctx context.Context, config *Config, configPath str return config, nil } + +func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier { + var sri interface{} = statusRecorder + mgmNotifier, _ := sri.(mgm.ConnStateNotifier) + return mgmNotifier +} + +func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal.ConnStateNotifier { + var sri interface{} = statusRecorder + notifier, _ := sri.(signal.ConnStateNotifier) + return notifier +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index bac263684..2e57a16c2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -72,7 +72,7 @@ func TestEngine_SSH(t *testing.T) { WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, peer.NewRecorder()) + }, peer.NewRecorder("https://mgm")) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -206,7 +206,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, peer.NewRecorder()) + }, peer.NewRecorder("https://mgm")) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), engine.wgInterface, engine.statusRecorder) engine.dnsServer = &dns.MockServer{ @@ -390,7 +390,7 @@ func TestEngine_Sync(t *testing.T) { WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, peer.NewRecorder()) + }, peer.NewRecorder("https://mgm")) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -548,7 +548,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, peer.NewRecorder()) + }, peer.NewRecorder("https://mgm")) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) assert.NoError(t, err, "shouldn't return error") input := struct { @@ -713,7 +713,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, peer.NewRecorder()) + }, peer.NewRecorder("https://mgm")) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, iface.DefaultMTU) assert.NoError(t, err, "shouldn't return error") @@ -978,7 +978,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin WgPort: wgPort, } - return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder()), nil + return NewEngine(ctx, cancel, signalClient, mgmtClient, conf, peer.NewRecorder("https://mgm")), nil } func startSignal() (*grpc.Server, string, error) { diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index ac4447956..7f9b263e4 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -49,7 +49,7 @@ func TestConn_GetKey(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder()) + conn, err := NewConn(connConf, NewRecorder("https://mgm")) if err != nil { return } @@ -83,7 +83,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder()) + conn, err := NewConn(connConf, NewRecorder("https://mgm")) if err != nil { return } @@ -116,7 +116,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { } func TestConn_Status(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder()) + conn, err := NewConn(connConf, NewRecorder("https://mgm")) if err != nil { return } @@ -143,7 +143,7 @@ func TestConn_Status(t *testing.T) { func TestConn_Close(t *testing.T) { - conn, err := NewConn(connConf, NewRecorder()) + conn, err := NewConn(connConf, NewRecorder("https://mgm")) if err != nil { return } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 3726a9d4a..6515cceb2 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -49,21 +49,24 @@ type FullStatus struct { // Status holds a state of peers, signal and management connections type Status struct { - mux sync.Mutex - peers map[string]State - changeNotify map[string]chan struct{} - signal SignalState - management ManagementState - localPeer LocalPeerState - offlinePeers []State + mux sync.Mutex + peers map[string]State + changeNotify map[string]chan struct{} + signalState bool + managementState bool + localPeer LocalPeerState + offlinePeers []State + mgmAddress string + signalAddress string } // NewRecorder returns a new Status instance -func NewRecorder() *Status { +func NewRecorder(mgmAddress string) *Status { return &Status{ peers: make(map[string]State), changeNotify: make(map[string]chan struct{}), offlinePeers: make([]State, 0), + mgmAddress: mgmAddress, } } @@ -193,43 +196,45 @@ func (d *Status) CleanLocalPeerState() { } // MarkManagementDisconnected sets ManagementState to disconnected -func (d *Status) MarkManagementDisconnected(managementURL string) { +func (d *Status) MarkManagementDisconnected() { d.mux.Lock() defer d.mux.Unlock() - d.management = ManagementState{ - URL: managementURL, - Connected: false, - } + d.managementState = false } // MarkManagementConnected sets ManagementState to connected -func (d *Status) MarkManagementConnected(managementURL string) { +func (d *Status) MarkManagementConnected() { d.mux.Lock() defer d.mux.Unlock() - d.management = ManagementState{ - URL: managementURL, - Connected: true, - } + d.managementState = true +} + +// UpdateSignalAddress update the address of the signal server +func (d *Status) UpdateSignalAddress(signalURL string) { + d.mux.Lock() + defer d.mux.Unlock() + d.signalAddress = signalURL +} + +// UpdateManagementAddress update the address of the management server +func (d *Status) UpdateManagementAddress(mgmAddress string) { + d.mux.Lock() + defer d.mux.Unlock() + d.mgmAddress = mgmAddress } // MarkSignalDisconnected sets SignalState to disconnected -func (d *Status) MarkSignalDisconnected(signalURL string) { +func (d *Status) MarkSignalDisconnected() { d.mux.Lock() defer d.mux.Unlock() - d.signal = SignalState{ - signalURL, - false, - } + d.signalState = false } // MarkSignalConnected sets SignalState to connected -func (d *Status) MarkSignalConnected(signalURL string) { +func (d *Status) MarkSignalConnected() { d.mux.Lock() defer d.mux.Unlock() - d.signal = SignalState{ - signalURL, - true, - } + d.signalState = true } // GetFullStatus gets full status @@ -238,9 +243,15 @@ func (d *Status) GetFullStatus() FullStatus { defer d.mux.Unlock() fullStatus := FullStatus{ - ManagementState: d.management, - SignalState: d.signal, - LocalPeerState: d.localPeer, + ManagementState: ManagementState{ + d.mgmAddress, + d.managementState, + }, + SignalState: SignalState{ + d.signalAddress, + d.signalState, + }, + LocalPeerState: d.localPeer, } for _, status := range d.peers { diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index a23620492..86ee933f6 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -8,7 +8,7 @@ import ( func TestAddPeer(t *testing.T) { key := "abc" - status := NewRecorder() + status := NewRecorder("https://mgm") err := status.AddPeer(key) assert.NoError(t, err, "shouldn't return error") @@ -22,7 +22,7 @@ func TestAddPeer(t *testing.T) { func TestGetPeer(t *testing.T) { key := "abc" - status := NewRecorder() + status := NewRecorder("https://mgm") err := status.AddPeer(key) assert.NoError(t, err, "shouldn't return error") @@ -38,7 +38,7 @@ func TestGetPeer(t *testing.T) { func TestUpdatePeerState(t *testing.T) { key := "abc" ip := "10.10.10.10" - status := NewRecorder() + status := NewRecorder("https://mgm") peerState := State{ PubKey: key, } @@ -58,7 +58,7 @@ func TestUpdatePeerState(t *testing.T) { func TestStatus_UpdatePeerFQDN(t *testing.T) { key := "abc" fqdn := "peer-a.netbird.local" - status := NewRecorder() + status := NewRecorder("https://mgm") peerState := State{ PubKey: key, } @@ -76,7 +76,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) { func TestGetPeerStateChangeNotifierLogic(t *testing.T) { key := "abc" ip := "10.10.10.10" - status := NewRecorder() + status := NewRecorder("https://mgm") peerState := State{ PubKey: key, } @@ -100,7 +100,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { func TestRemovePeer(t *testing.T) { key := "abc" - status := NewRecorder() + status := NewRecorder("https://mgm") peerState := State{ PubKey: key, } @@ -123,7 +123,7 @@ func TestUpdateLocalPeerState(t *testing.T) { PubKey: "abc", KernelInterface: false, } - status := NewRecorder() + status := NewRecorder("https://mgm") status.UpdateLocalPeerState(localPeerState) @@ -137,7 +137,7 @@ func TestCleanLocalPeerState(t *testing.T) { PubKey: "abc", KernelInterface: false, } - status := NewRecorder() + status := NewRecorder("https://mgm") status.localPeer = localPeerState @@ -151,29 +151,23 @@ func TestUpdateSignalState(t *testing.T) { var tests = []struct { name string connected bool - want SignalState + want bool }{ - {"should mark as connected", true, SignalState{ - - URL: url, - Connected: true, - }}, - {"should mark as disconnected", false, SignalState{ - URL: url, - Connected: false, - }}, + {"should mark as connected", true, true}, + {"should mark as disconnected", false, false}, } - status := NewRecorder() + status := NewRecorder("https://mgm") + status.UpdateSignalAddress(url) for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.connected { - status.MarkSignalConnected(url) + status.MarkSignalConnected() } else { - status.MarkSignalDisconnected(url) + status.MarkSignalDisconnected() } - assert.Equal(t, test.want, status.signal, "signal status should be equal") + assert.Equal(t, test.want, status.signalState, "signal status should be equal") }) } } @@ -183,29 +177,22 @@ func TestUpdateManagementState(t *testing.T) { var tests = []struct { name string connected bool - want ManagementState + want bool }{ - {"should mark as connected", true, ManagementState{ - - URL: url, - Connected: true, - }}, - {"should mark as disconnected", false, ManagementState{ - URL: url, - Connected: false, - }}, + {"should mark as connected", true, true}, + {"should mark as disconnected", false, false}, } - status := NewRecorder() + status := NewRecorder(url) for _, test := range tests { t.Run(test.name, func(t *testing.T) { if test.connected { - status.MarkManagementConnected(url) + status.MarkManagementConnected() } else { - status.MarkManagementDisconnected(url) + status.MarkManagementDisconnected() } - assert.Equal(t, test.want, status.management, "signal status should be equal") + assert.Equal(t, test.want, status.managementState, "signalState status should be equal") }) } } @@ -213,12 +200,13 @@ func TestUpdateManagementState(t *testing.T) { func TestGetFullStatus(t *testing.T) { key1 := "abc" key2 := "def" + signalAddr := "https://signal" managementState := ManagementState{ - URL: "https://signal", + URL: "https://mgm", Connected: true, } signalState := SignalState{ - URL: "https://signal", + URL: signalAddr, Connected: true, } peerState1 := State{ @@ -229,10 +217,11 @@ func TestGetFullStatus(t *testing.T) { PubKey: key2, } - status := NewRecorder() + status := NewRecorder("https://mgm") + status.UpdateSignalAddress(signalAddr) - status.management = managementState - status.signal = signalState + status.managementState = managementState.Connected + status.signalState = signalState.Connected status.peers[key1] = peerState1 status.peers[key2] = peerState2 diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 570b6e1a3..9d3b5f5ff 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -398,7 +398,7 @@ func TestManagerUpdateRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - statusRecorder := peer.NewRecorder() + statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder) defer routeManager.Stop() diff --git a/client/server/server.go b/client/server/server.go index 2531dc97e..e1adc0239 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -96,7 +96,9 @@ func (s *Server) Start() error { s.config = config if s.statusRecorder == nil { - s.statusRecorder = peer.NewRecorder() + s.statusRecorder = peer.NewRecorder(config.ManagementURL.String()) + } else { + s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String()) } go func() { @@ -386,7 +388,9 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes } if s.statusRecorder == nil { - s.statusRecorder = peer.NewRecorder() + s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String()) + } else { + s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) } go func() { @@ -430,7 +434,9 @@ func (s *Server) Status( statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()} if s.statusRecorder == nil { - s.statusRecorder = peer.NewRecorder() + s.statusRecorder = peer.NewRecorder(s.config.ManagementURL.String()) + } else { + s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) } if msg.GetFullPeerStatus { diff --git a/management/client/grpc.go b/management/client/grpc.go index 0da261b98..ad66aa3d6 100644 --- a/management/client/grpc.go +++ b/management/client/grpc.go @@ -4,15 +4,13 @@ import ( "context" "crypto/tls" "fmt" - "google.golang.org/grpc/codes" - gstatus "google.golang.org/grpc/status" "io" + "sync" "time" - "github.com/cenkalti/backoff/v4" - "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/management/proto" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -20,13 +18,26 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + + "github.com/cenkalti/backoff/v4" + "github.com/netbirdio/netbird/client/system" + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/proto" ) +// ConnStateNotifier is a wrapper interface of the status recorders +type ConnStateNotifier interface { + MarkManagementDisconnected() + MarkManagementConnected() +} + type GrpcClient struct { - key wgtypes.Key - realClient proto.ManagementServiceClient - ctx context.Context - conn *grpc.ClientConn + key wgtypes.Key + realClient proto.ManagementServiceClient + ctx context.Context + conn *grpc.ClientConn + connStateCallback ConnStateNotifier + connStateCallbackLock sync.RWMutex } // NewClient creates a new client to Management service @@ -56,10 +67,11 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE realClient := proto.NewManagementServiceClient(conn) return &GrpcClient{ - key: ourPrivateKey, - realClient: realClient, - ctx: ctx, - conn: conn, + key: ourPrivateKey, + realClient: realClient, + ctx: ctx, + conn: conn, + connStateCallbackLock: sync.RWMutex{}, }, nil } @@ -68,6 +80,13 @@ func (c *GrpcClient) Close() error { return c.conn.Close() } +// SetConnStateListener set the ConnStateNotifier +func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) { + c.connStateCallbackLock.Lock() + defer c.connStateCallbackLock.Unlock() + c.connStateCallback = notifier +} + // defaultBackoff is a basic backoff mechanism for general issues func defaultBackoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(&backoff.ExponentialBackOff{ @@ -121,7 +140,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error } log.Infof("connected to the Management Service stream") - + c.notifyConnected() // blocking until error err = c.receiveEvents(stream, *serverPubKey, msgHandler) if err != nil { @@ -131,6 +150,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error // we need this reset because after a successful connection and a consequent error, backoff lib doesn't // reset times and next try will start with a long delay backOff.Reset() + c.notifyDisconnected() log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) return err } @@ -298,6 +318,26 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D return flowInfoResp, nil } +func (c *GrpcClient) notifyDisconnected() { + c.connStateCallbackLock.RLock() + defer c.connStateCallbackLock.RUnlock() + + if c.connStateCallback == nil { + return + } + c.connStateCallback.MarkManagementDisconnected() +} + +func (c *GrpcClient) notifyConnected() { + c.connStateCallbackLock.RLock() + defer c.connStateCallbackLock.RUnlock() + + if c.connStateCallback == nil { + return + } + c.connStateCallback.MarkManagementConnected() +} + func infoToMetaData(info *system.Info) *proto.PeerSystemMeta { if info == nil { return nil diff --git a/signal/client/grpc.go b/signal/client/grpc.go index 36b4833fb..5b0ae39ef 100644 --- a/signal/client/grpc.go +++ b/signal/client/grpc.go @@ -22,6 +22,12 @@ import ( "time" ) +// ConnStateNotifier is a wrapper interface of the status recorder +type ConnStateNotifier interface { + MarkSignalDisconnected() + MarkSignalConnected() +} + // GrpcClient Wraps the Signal Exchange Service gRpc client type GrpcClient struct { key wgtypes.Key @@ -34,6 +40,9 @@ type GrpcClient struct { mux sync.Mutex // StreamConnected indicates whether this client is StreamConnected to the Signal stream status Status + + connStateCallback ConnStateNotifier + connStateCallbackLock sync.RWMutex } func (c *GrpcClient) StreamConnected() bool { @@ -78,15 +87,23 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo log.Debugf("connected to Signal Service: %v", conn.Target()) return &GrpcClient{ - realClient: proto.NewSignalExchangeClient(conn), - ctx: ctx, - signalConn: conn, - key: key, - mux: sync.Mutex{}, - status: StreamDisconnected, + realClient: proto.NewSignalExchangeClient(conn), + ctx: ctx, + signalConn: conn, + key: key, + mux: sync.Mutex{}, + status: StreamDisconnected, + connStateCallbackLock: sync.RWMutex{}, }, nil } +// SetConnStateListener set the ConnStateNotifier +func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) { + c.connStateCallbackLock.Lock() + defer c.connStateCallbackLock.Unlock() + c.connStateCallback = notifier +} + // defaultBackoff is a basic backoff mechanism for general issues func defaultBackoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(&backoff.ExponentialBackOff{ @@ -134,13 +151,14 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error { c.notifyStreamConnected() log.Infof("connected to the Signal Service stream") - + c.notifyConnected() // start receiving messages from the Signal stream (from other peers through signal) err = c.receive(stream, msgHandler) if err != nil { // we need this reset because after a successful connection and a consequent error, backoff lib doesn't // reset times and next try will start with a long delay backOff.Reset() + c.notifyDisconnected() log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err) return err } @@ -341,3 +359,23 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, } } } + +func (c *GrpcClient) notifyDisconnected() { + c.connStateCallbackLock.RLock() + defer c.connStateCallbackLock.RUnlock() + + if c.connStateCallback == nil { + return + } + c.connStateCallback.MarkSignalDisconnected() +} + +func (c *GrpcClient) notifyConnected() { + c.connStateCallbackLock.RLock() + defer c.connStateCallbackLock.RUnlock() + + if c.connStateCallback == nil { + return + } + c.connStateCallback.MarkSignalConnected() +}