[signal] fix goroutines and memory leak on forward messages between peers (#3896)

This commit is contained in:
Vlad
2025-08-27 18:30:49 +02:00
committed by GitHub
parent 7ce5507c05
commit 99bd34c02a
7 changed files with 402 additions and 126 deletions

View File

@@ -2,7 +2,9 @@ package server
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
@@ -15,9 +17,9 @@ import (
"github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
"github.com/netbirdio/netbird/shared/signal/proto"
)
const (
@@ -27,6 +29,8 @@ const (
labelTypeNotRegistered = "not_registered"
labelTypeStream = "stream"
labelTypeMessage = "message"
labelTypeTimeout = "timeout"
labelTypeDisconnected = "disconnected"
labelError = "error"
labelErrorMissingId = "missing_id"
@@ -37,6 +41,12 @@ const (
labelRegistrationStatus = "status"
labelRegistrationFound = "found"
labelRegistrationNotFound = "not_found"
sendTimeout = 10 * time.Second
)
var (
ErrPeerRegisteredAgain = errors.New("peer registered again")
)
// Server an instance of a Signal server
@@ -45,6 +55,10 @@ type Server struct {
proto.UnimplementedSignalExchangeServer
dispatcher *dispatcher.Dispatcher
metrics *metrics.AppMetrics
successHeader metadata.MD
sendTimeout time.Duration
}
// NewServer creates a new Signal server
@@ -59,10 +73,19 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
return nil, fmt.Errorf("creating dispatcher: %v", err)
}
sTimeout := sendTimeout
to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT")
if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 {
log.Trace("using custom send timeout ", parsed)
sTimeout = parsed
}
s := &Server{
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
dispatcher: d,
registry: peer.NewRegistry(appMetrics),
metrics: appMetrics,
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
sendTimeout: sTimeout,
}
return s, nil
@@ -82,7 +105,8 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
// ConnectStream connects to the exchange stream
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
p, err := s.RegisterPeer(stream)
ctx, cancel := context.WithCancel(context.Background())
p, err := s.RegisterPeer(stream, cancel)
if err != nil {
return err
}
@@ -90,8 +114,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
defer s.DeregisterPeer(p)
// needed to confirm that the peer has been registered so that the client can proceed
header := metadata.Pairs(proto.HeaderRegistered, "1")
err = stream.SendHeader(header)
err = stream.SendHeader(s.successHeader)
if err != nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader)))
return err
@@ -99,27 +122,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
<-stream.Context().Done()
log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
return nil
select {
case <-stream.Context().Done():
log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
return nil
case <-ctx.Done():
return ErrPeerRegisteredAgain
}
}
func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) {
log.Debugf("registering new peer")
meta, hasMeta := metadata.FromIncomingContext(stream.Context())
if !hasMeta {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
}
id, found := meta[proto.HeaderId]
if !found {
id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId)
if id == nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
}
p := peer.NewPeer(id[0], stream)
s.registry.Register(p)
p := peer.NewPeer(id[0], stream, cancel)
if err := s.registry.Register(p); err != nil {
return nil, err
}
err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
if err != nil {
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration)))
@@ -131,8 +154,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
func (s *Server) DeregisterPeer(p *peer.Peer) {
log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
s.registry.Deregister(p)
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
s.registry.Deregister(p)
}
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
@@ -145,7 +168,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
if !found {
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
// todo respond to the sender?
return
}
@@ -153,16 +176,34 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
start := time.Now()
// forward the message to the target peer
if err := dstPeer.Stream.Send(msg); err != nil {
log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
// todo respond to the sender?
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
return
}
sendResultChan := make(chan error, 1)
go func() {
select {
case sendResultChan <- dstPeer.Stream.Send(msg):
return
case <-dstPeer.Stream.Context().Done():
return
}
}()
// in milliseconds
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
select {
case err := <-sendResultChan:
if err != nil {
log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err)
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
return
}
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
s.metrics.MessagesForwarded.Add(ctx, 1)
s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
case <-dstPeer.Stream.Context().Done():
log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey)
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected)))
case <-time.After(s.sendTimeout):
dstPeer.Cancel() // cancel the peer context to trigger deregistration
log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey)
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout)))
}
}