From aa55fba5eecfc7cb1daabdfc779e86709940e3d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Sat, 29 Jun 2024 14:13:05 +0200 Subject: [PATCH] Add client side heartbeat handling --- relay/client/client.go | 90 +++++++++++++++++++++++++---------- relay/healthcheck/receiver.go | 21 ++++---- 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/relay/client/client.go b/relay/client/client.go index b5d4392b3..5d5116477 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/relay/client/dialer/ws" + "github.com/netbirdio/netbird/relay/healthcheck" "github.com/netbirdio/netbird/relay/messages" ) @@ -23,6 +24,27 @@ var ( ErrConnAlreadyExists = fmt.Errorf("connection already exists") ) +type internalStopFlag struct { + sync.Mutex + stop bool +} + +func newInternalStopFlag() *internalStopFlag { + return &internalStopFlag{} +} + +func (isf *internalStopFlag) set() { + isf.Lock() + defer isf.Unlock() + isf.stop = true +} + +func (isf *internalStopFlag) isSet() bool { + isf.Lock() + defer isf.Unlock() + return isf.stop +} + // Msg carry the payload from the server to the client. With this sturct, the net.Conn can free the buffer. 
type Msg struct { Payload []byte @@ -75,7 +97,6 @@ func (cc *connContainer) close() { type Client struct { log *log.Entry parentCtx context.Context - ctxCancel context.CancelFunc serverAddress string hashedID []byte @@ -84,7 +105,7 @@ type Client struct { relayConn net.Conn conns map[string]*connContainer serviceIsRunning bool - mu sync.Mutex + mu sync.Mutex // protect serviceIsRunning and conns readLoopMutex sync.Mutex wgReadLoop sync.WaitGroup @@ -100,7 +121,6 @@ func NewClient(ctx context.Context, serverAddress, peerID string) *Client { return &Client{ log: log.WithField("client_id", hashedStringId), parentCtx: ctx, - ctxCancel: func() {}, serverAddress: serverAddress, hashedID: hashedID, bufPool: &sync.Pool{ @@ -133,15 +153,6 @@ func (c *Client) Connect() error { c.serviceIsRunning = true - var ctx context.Context - ctx, c.ctxCancel = context.WithCancel(c.parentCtx) - context.AfterFunc(ctx, func() { - cErr := c.close(false) - if cErr != nil { - log.Errorf("failed to close relay connection: %s", cErr) - } - }) - c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) @@ -200,7 +211,7 @@ func (c *Client) HasConns() bool { // Close closes the connection to the relay server and all connections to other peers. 
func (c *Client) Close() error { - return c.close(false) + return c.close(true) } func (c *Client) connect() error { @@ -257,10 +268,13 @@ func (c *Client) handShake() error { } func (c *Client) readLoop(relayConn net.Conn) { + internallyStoppedFlag := newInternalStopFlag() + hc := healthcheck.NewReceiver() + go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag) + var ( - errExit error - n int - closedByServer bool + errExit error + n int ) for { bufPtr := c.bufPool.Get().(*[]byte) @@ -268,7 +282,7 @@ func (c *Client) readLoop(relayConn net.Conn) { n, errExit = relayConn.Read(buf) if errExit != nil { c.mu.Lock() - if c.serviceIsRunning { + if c.serviceIsRunning && !internallyStoppedFlag.isSet() { c.log.Debugf("failed to read message from relay server: %s", errExit) } c.mu.Unlock() @@ -283,15 +297,19 @@ func (c *Client) readLoop(relayConn net.Conn) { switch msgType { case messages.MsgTypeHealthCheck: + log.Debugf("on new heartbeat") msg := messages.MarshalHealthcheck() - _, err := c.relayConn.Write(msg) - if err != nil { - c.log.Errorf("failed to send heartbeat response: %s", err) + _, wErr := c.relayConn.Write(msg) + if wErr != nil && c.serviceIsRunning && !internallyStoppedFlag.isSet() { // report only real write failures, and only while not shutting down + c.log.Errorf("failed to send heartbeat: %s", wErr) } + hc.Heartbeat() case messages.MsgTypeTransport: peerID, payload, err := messages.UnmarshalTransportMsg(buf[:n]) if err != nil { - c.log.Errorf("failed to parse transport message: %v", err) + if c.serviceIsRunning && !internallyStoppedFlag.isSet() { + c.log.Errorf("failed to parse transport message: %v", err) + } continue } stringID := messages.HashIDToString(peerID) @@ -313,16 +331,16 @@ func (c *Client) readLoop(relayConn net.Conn) { bufPtr: bufPtr, Payload: payload}) case messages.MsgTypeClose: - closedByServer = true log.Debugf("relay connection close by server") goto Exit } } Exit: + hc.Stop() c.notifyDisconnected() c.wgReadLoop.Done() - _ = c.close(closedByServer) + _ = c.close(false) } // todo check by reference too, the id 
is not enought because the id come from the outer conn @@ -352,6 +370,27 @@ func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) { return n, err } +func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { + for { + select { + case _, ok := <-hc.OnTimeout: + if !ok { + return + } + c.log.Errorf("health check timeout") + internalStopFlag.set() + _ = conn.Close() // ignore the err because the readLoop will handle it + return + case <-c.parentCtx.Done(): + err := c.close(true) + if err != nil { + log.Errorf("failed to teardown connection: %s", err) + } + return + } + } +} + func (c *Client) closeAllConns() { for _, container := range c.conns { container.close() @@ -374,7 +413,7 @@ func (c *Client) closeConn(id string) error { return nil } -func (c *Client) close(byServer bool) error { +func (c *Client) close(gracefullyExit bool) error { c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -387,7 +426,7 @@ func (c *Client) close(byServer bool) error { c.serviceIsRunning = false c.closeAllConns() - if !byServer { + if gracefullyExit { c.writeCloseMsg() err = c.relayConn.Close() } @@ -395,7 +434,6 @@ func (c *Client) close(byServer bool) error { c.wgReadLoop.Wait() c.log.Infof("relay connection closed with: %s", c.serverAddress) - c.ctxCancel() return err } diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go index 147fe2d5f..e1ef17e0e 100644 --- a/relay/healthcheck/receiver.go +++ b/relay/healthcheck/receiver.go @@ -20,7 +20,7 @@ type Receiver struct { ctx context.Context ctxCancel context.CancelFunc heartbeat chan struct{} - live bool + alive bool } // NewReceiver creates a new healthcheck receiver and start the timer in the background @@ -60,19 +60,24 @@ func (r *Receiver) waitForHealthcheck() { for { select { case <-r.heartbeat: - r.live = true + r.alive = true case <-ticker.C: - if r.live { - r.live = false + if r.alive { + r.alive = false continue } - 
select { - case r.OnTimeout <- struct{}{}: - default: - } + + r.notifyTimeout() return case <-r.ctx.Done(): return } } } + +func (r *Receiver) notifyTimeout() { + select { + case r.OnTimeout <- struct{}{}: + default: + } +}