mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
[signal] fix goroutines and memory leak on forward messages between peers (#3896)
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
// nolint:gosec
|
||||
_ "net/http/pprof"
|
||||
"strings"
|
||||
|
||||
@@ -5,10 +5,16 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"errors"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPeerAlreadyRegistered = errors.New("peer already registered")
|
||||
)
|
||||
|
||||
// Peer representation of a connected Peer
|
||||
@@ -23,15 +29,18 @@ type Peer struct {
|
||||
|
||||
// registration time
|
||||
RegisteredAt time.Time
|
||||
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewPeer creates a new instance of a connected Peer
|
||||
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
|
||||
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
|
||||
return &Peer{
|
||||
Id: id,
|
||||
Stream: stream,
|
||||
StreamID: time.Now().UnixNano(),
|
||||
RegisteredAt: time.Now(),
|
||||
Cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,20 +78,24 @@ func (registry *Registry) IsPeerRegistered(peerId string) bool {
|
||||
}
|
||||
|
||||
// Register registers peer in the registry
|
||||
func (registry *Registry) Register(peer *Peer) {
|
||||
func (registry *Registry) Register(peer *Peer) error {
|
||||
start := time.Now()
|
||||
|
||||
registry.regMutex.Lock()
|
||||
defer registry.regMutex.Unlock()
|
||||
|
||||
// can be that peer already exists, but it is fine (e.g. reconnect)
|
||||
p, loaded := registry.Peers.LoadOrStore(peer.Id, peer)
|
||||
if loaded {
|
||||
pp := p.(*Peer)
|
||||
log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
|
||||
peer.Id, peer.StreamID, pp.StreamID)
|
||||
registry.Peers.Store(peer.Id, peer)
|
||||
return
|
||||
if peer.StreamID > pp.StreamID {
|
||||
log.Tracef("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
|
||||
peer.Id, peer.StreamID, pp.StreamID)
|
||||
if swapped := registry.Peers.CompareAndSwap(peer.Id, pp, peer); !swapped {
|
||||
return registry.Register(peer)
|
||||
}
|
||||
pp.Cancel()
|
||||
log.Debugf("peer re-registered [%s]", peer.Id)
|
||||
return nil
|
||||
}
|
||||
return ErrPeerAlreadyRegistered
|
||||
}
|
||||
|
||||
log.Debugf("peer registered [%s]", peer.Id)
|
||||
@@ -92,22 +105,13 @@ func (registry *Registry) Register(peer *Peer) {
|
||||
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
|
||||
|
||||
registry.metrics.Registrations.Add(context.Background(), 1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deregister Peer from the Registry (usually once it disconnects)
|
||||
func (registry *Registry) Deregister(peer *Peer) {
|
||||
registry.regMutex.Lock()
|
||||
defer registry.regMutex.Unlock()
|
||||
|
||||
p, loaded := registry.Peers.LoadAndDelete(peer.Id)
|
||||
if loaded {
|
||||
pp := p.(*Peer)
|
||||
if peer.StreamID < pp.StreamID {
|
||||
registry.Peers.Store(peer.Id, p)
|
||||
log.Debugf("attempted to remove newer registered stream of a peer [%s] [newer streamID %d, previous StreamID %d]. Ignoring.",
|
||||
peer.Id, pp.StreamID, peer.StreamID)
|
||||
return
|
||||
}
|
||||
if deleted := registry.Peers.CompareAndDelete(peer.Id, peer); deleted {
|
||||
registry.metrics.ActivePeers.Add(context.Background(), -1)
|
||||
log.Debugf("peer deregistered [%s]", peer.Id)
|
||||
registry.metrics.Deregistrations.Add(context.Background(), 1)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
)
|
||||
|
||||
@@ -19,12 +24,16 @@ func TestRegistry_ShouldNotDeregisterWhenHasNewerStreamRegistered(t *testing.T)
|
||||
|
||||
peerID := "peer"
|
||||
|
||||
olderPeer := NewPeer(peerID, nil)
|
||||
r.Register(olderPeer)
|
||||
_, cancel1 := context.WithCancel(context.Background())
|
||||
olderPeer := NewPeer(peerID, nil, cancel1)
|
||||
err = r.Register(olderPeer)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Nanosecond)
|
||||
|
||||
newerPeer := NewPeer(peerID, nil)
|
||||
r.Register(newerPeer)
|
||||
_, cancel2 := context.WithCancel(context.Background())
|
||||
newerPeer := NewPeer(peerID, nil, cancel2)
|
||||
err = r.Register(newerPeer)
|
||||
require.NoError(t, err)
|
||||
registered, _ := r.Get(olderPeer.Id)
|
||||
|
||||
assert.NotNil(t, registered, "peer can't be nil")
|
||||
@@ -59,10 +68,14 @@ func TestRegistry_Register(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
r := NewRegistry(metrics)
|
||||
peer1 := NewPeer("test_peer_1", nil)
|
||||
peer2 := NewPeer("test_peer_2", nil)
|
||||
r.Register(peer1)
|
||||
r.Register(peer2)
|
||||
_, cancel1 := context.WithCancel(context.Background())
|
||||
peer1 := NewPeer("test_peer_1", nil, cancel1)
|
||||
_, cancel2 := context.WithCancel(context.Background())
|
||||
peer2 := NewPeer("test_peer_2", nil, cancel2)
|
||||
err = r.Register(peer1)
|
||||
require.NoError(t, err)
|
||||
err = r.Register(peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if _, ok := r.Get("test_peer_1"); !ok {
|
||||
t.Errorf("expected test_peer_1 not found in the registry")
|
||||
@@ -78,10 +91,14 @@ func TestRegistry_Deregister(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
r := NewRegistry(metrics)
|
||||
peer1 := NewPeer("test_peer_1", nil)
|
||||
peer2 := NewPeer("test_peer_2", nil)
|
||||
r.Register(peer1)
|
||||
r.Register(peer2)
|
||||
_, cancel1 := context.WithCancel(context.Background())
|
||||
peer1 := NewPeer("test_peer_1", nil, cancel1)
|
||||
_, cancel2 := context.WithCancel(context.Background())
|
||||
peer2 := NewPeer("test_peer_2", nil, cancel2)
|
||||
err = r.Register(peer1)
|
||||
require.NoError(t, err)
|
||||
err = r.Register(peer2)
|
||||
require.NoError(t, err)
|
||||
|
||||
r.Deregister(peer1)
|
||||
|
||||
@@ -94,3 +111,213 @@ func TestRegistry_Deregister(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRegistry_MultipleRegister_Concurrency(t *testing.T) {
|
||||
|
||||
metrics, err := metrics.NewAppMetrics(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
registry := NewRegistry(metrics)
|
||||
|
||||
numGoroutines := 1000
|
||||
|
||||
ids := make(chan int64, numGoroutines)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
peerID := "peer-concurrent"
|
||||
for i := range numGoroutines {
|
||||
go func(routineIndex int) {
|
||||
defer wg.Done()
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
peer := NewPeer(peerID, nil, cancel)
|
||||
_ = registry.Register(peer)
|
||||
ids <- peer.StreamID
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ids)
|
||||
maxId := int64(0)
|
||||
for id := range ids {
|
||||
maxId = max(maxId, id)
|
||||
}
|
||||
|
||||
peer, ok := registry.Get(peerID)
|
||||
require.True(t, ok, "expected peer to be registered")
|
||||
require.Equal(t, maxId, peer.StreamID, "expected the highest StreamID to be registered")
|
||||
}
|
||||
|
||||
func Benchmark_MultipleRegister_Concurrency(b *testing.B) {
|
||||
|
||||
metrics, err := metrics.NewAppMetrics(otel.Meter(""))
|
||||
require.NoError(b, err)
|
||||
|
||||
numGoroutines := 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
peerID := "peer-concurrent"
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
b.Run("multiple-register", func(b *testing.B) {
|
||||
registry := NewRegistry(metrics)
|
||||
b.ResetTimer()
|
||||
for j := 0; j < b.N; j++ {
|
||||
wg.Add(numGoroutines)
|
||||
for i := range numGoroutines {
|
||||
go func(routineIndex int) {
|
||||
defer wg.Done()
|
||||
|
||||
peer := NewPeer(peerID, nil, cancel)
|
||||
_ = registry.Register(peer)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry_MultipleDeregister_Concurrency(t *testing.T) {
|
||||
|
||||
metrics, err := metrics.NewAppMetrics(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
registry := NewRegistry(metrics)
|
||||
|
||||
numGoroutines := 1000
|
||||
|
||||
ids := make(chan int64, numGoroutines)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
peerID := "peer-concurrent"
|
||||
for i := range numGoroutines {
|
||||
go func(routineIndex int) {
|
||||
defer wg.Done()
|
||||
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
peer := NewPeer(peerID, nil, cancel)
|
||||
_ = registry.Register(peer)
|
||||
ids <- peer.StreamID
|
||||
registry.Deregister(peer)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ids)
|
||||
maxId := int64(0)
|
||||
for id := range ids {
|
||||
maxId = max(maxId, id)
|
||||
}
|
||||
|
||||
_, ok := registry.Get(peerID)
|
||||
require.False(t, ok, "expected peer to be deregistered")
|
||||
}
|
||||
|
||||
func Benchmark_MultipleDeregister_Concurrency(b *testing.B) {
|
||||
|
||||
metrics, err := metrics.NewAppMetrics(otel.Meter(""))
|
||||
require.NoError(b, err)
|
||||
|
||||
numGoroutines := 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
peerID := "peer-concurrent"
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
b.Run("register-deregister", func(b *testing.B) {
|
||||
registry := NewRegistry(metrics)
|
||||
b.ResetTimer()
|
||||
for j := 0; j < b.N; j++ {
|
||||
wg.Add(numGoroutines)
|
||||
for i := range numGoroutines {
|
||||
go func(routineIndex int) {
|
||||
defer wg.Done()
|
||||
|
||||
peer := NewPeer(peerID, nil, cancel)
|
||||
_ = registry.Register(peer)
|
||||
time.Sleep(time.Nanosecond)
|
||||
registry.Deregister(peer)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type mockConnectStreamServer struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *mockConnectStreamServer) Context() context.Context {
|
||||
return m.ctx
|
||||
}
|
||||
|
||||
func (m *mockConnectStreamServer) SendHeader(md metadata.MD) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConnectStreamServer) Send(msg *proto.EncryptedMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConnectStreamServer) Recv() (*proto.EncryptedMessage, error) {
|
||||
<-m.ctx.Done()
|
||||
return nil, m.ctx.Err()
|
||||
}
|
||||
|
||||
func TestReconnectHandling(t *testing.T) {
|
||||
metrics, err := metrics.NewAppMetrics(otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
registry := NewRegistry(metrics)
|
||||
peerID := "test-peer-reconnect"
|
||||
|
||||
ctx1, cancel1 := context.WithCancel(context.Background())
|
||||
defer cancel1()
|
||||
stream1 := &mockConnectStreamServer{ctx: ctx1}
|
||||
peer1 := NewPeer(peerID, stream1, cancel1)
|
||||
|
||||
err = registry.Register(peer1)
|
||||
require.NoError(t, err, "first registration should succeed")
|
||||
|
||||
p, found := registry.Get(peerID)
|
||||
require.True(t, found, "peer should be found in the registry")
|
||||
require.Equal(t, peer1.StreamID, p.StreamID, "StreamID of registered peer should match")
|
||||
|
||||
time.Sleep(time.Nanosecond)
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
defer cancel2()
|
||||
stream2 := &mockConnectStreamServer{ctx: ctx2}
|
||||
peer2 := NewPeer(peerID, stream2, cancel2)
|
||||
|
||||
err = registry.Register(peer2)
|
||||
require.NoError(t, err, "reconnect registration should succeed")
|
||||
|
||||
select {
|
||||
case <-ctx1.Done():
|
||||
require.ErrorIs(t, ctx1.Err(), context.Canceled, "context of old stream should be canceled after successful reconnection")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("context of old stream was not canceled after reconnection")
|
||||
}
|
||||
|
||||
p, found = registry.Get(peerID)
|
||||
require.True(t, found)
|
||||
require.Equal(t, peer2.StreamID, p.StreamID, "registered peer should have the new StreamID after reconnection")
|
||||
|
||||
ctx3, cancel3 := context.WithCancel(context.Background())
|
||||
defer cancel3()
|
||||
stream3 := &mockConnectStreamServer{ctx: ctx3}
|
||||
stalePeer := NewPeer(peerID, stream3, cancel3)
|
||||
stalePeer.StreamID = peer1.StreamID
|
||||
|
||||
err = registry.Register(stalePeer)
|
||||
require.ErrorIs(t, err, ErrPeerAlreadyRegistered, "reconnecting with an old StreamID should return an error")
|
||||
|
||||
p, found = registry.Get(peerID)
|
||||
require.True(t, found)
|
||||
require.Equal(t, peer2.StreamID, p.StreamID, "active peer should still be the one with the latest StreamID")
|
||||
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
t.Fatal("context of the new stream should not be canceled after trying to register with an old StreamID")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -15,9 +17,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/metrics"
|
||||
"github.com/netbirdio/netbird/signal/peer"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,6 +29,8 @@ const (
|
||||
labelTypeNotRegistered = "not_registered"
|
||||
labelTypeStream = "stream"
|
||||
labelTypeMessage = "message"
|
||||
labelTypeTimeout = "timeout"
|
||||
labelTypeDisconnected = "disconnected"
|
||||
|
||||
labelError = "error"
|
||||
labelErrorMissingId = "missing_id"
|
||||
@@ -37,6 +41,12 @@ const (
|
||||
labelRegistrationStatus = "status"
|
||||
labelRegistrationFound = "found"
|
||||
labelRegistrationNotFound = "not_found"
|
||||
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPeerRegisteredAgain = errors.New("peer registered again")
|
||||
)
|
||||
|
||||
// Server an instance of a Signal server
|
||||
@@ -45,6 +55,10 @@ type Server struct {
|
||||
proto.UnimplementedSignalExchangeServer
|
||||
dispatcher *dispatcher.Dispatcher
|
||||
metrics *metrics.AppMetrics
|
||||
|
||||
successHeader metadata.MD
|
||||
|
||||
sendTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewServer creates a new Signal server
|
||||
@@ -59,10 +73,19 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
|
||||
return nil, fmt.Errorf("creating dispatcher: %v", err)
|
||||
}
|
||||
|
||||
sTimeout := sendTimeout
|
||||
to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT")
|
||||
if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 {
|
||||
log.Trace("using custom send timeout ", parsed)
|
||||
sTimeout = parsed
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
dispatcher: d,
|
||||
registry: peer.NewRegistry(appMetrics),
|
||||
metrics: appMetrics,
|
||||
dispatcher: d,
|
||||
registry: peer.NewRegistry(appMetrics),
|
||||
metrics: appMetrics,
|
||||
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
|
||||
sendTimeout: sTimeout,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
@@ -82,7 +105,8 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.
|
||||
|
||||
// ConnectStream connects to the exchange stream
|
||||
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
|
||||
p, err := s.RegisterPeer(stream)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
p, err := s.RegisterPeer(stream, cancel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,8 +114,7 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
|
||||
defer s.DeregisterPeer(p)
|
||||
|
||||
// needed to confirm that the peer has been registered so that the client can proceed
|
||||
header := metadata.Pairs(proto.HeaderRegistered, "1")
|
||||
err = stream.SendHeader(header)
|
||||
err = stream.SendHeader(s.successHeader)
|
||||
if err != nil {
|
||||
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader)))
|
||||
return err
|
||||
@@ -99,27 +122,27 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
|
||||
|
||||
log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
|
||||
|
||||
<-stream.Context().Done()
|
||||
log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
|
||||
return nil
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ErrPeerRegisteredAgain
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
|
||||
func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) {
|
||||
log.Debugf("registering new peer")
|
||||
meta, hasMeta := metadata.FromIncomingContext(stream.Context())
|
||||
if !hasMeta {
|
||||
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingMeta)))
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
|
||||
}
|
||||
|
||||
id, found := meta[proto.HeaderId]
|
||||
if !found {
|
||||
id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId)
|
||||
if id == nil {
|
||||
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
|
||||
}
|
||||
|
||||
p := peer.NewPeer(id[0], stream)
|
||||
s.registry.Register(p)
|
||||
p := peer.NewPeer(id[0], stream, cancel)
|
||||
if err := s.registry.Register(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
|
||||
if err != nil {
|
||||
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration)))
|
||||
@@ -131,8 +154,8 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (
|
||||
|
||||
func (s *Server) DeregisterPeer(p *peer.Peer) {
|
||||
log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
|
||||
s.registry.Deregister(p)
|
||||
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
|
||||
s.registry.Deregister(p)
|
||||
}
|
||||
|
||||
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
|
||||
@@ -145,7 +168,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
|
||||
if !found {
|
||||
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
|
||||
log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
|
||||
log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
|
||||
// todo respond to the sender?
|
||||
return
|
||||
}
|
||||
@@ -153,16 +176,34 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
|
||||
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
|
||||
start := time.Now()
|
||||
|
||||
// forward the message to the target peer
|
||||
if err := dstPeer.Stream.Send(msg); err != nil {
|
||||
log.Tracef("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
|
||||
// todo respond to the sender?
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
|
||||
return
|
||||
}
|
||||
sendResultChan := make(chan error, 1)
|
||||
go func() {
|
||||
select {
|
||||
case sendResultChan <- dstPeer.Stream.Send(msg):
|
||||
return
|
||||
case <-dstPeer.Stream.Context().Done():
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// in milliseconds
|
||||
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
|
||||
s.metrics.MessagesForwarded.Add(ctx, 1)
|
||||
s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
|
||||
select {
|
||||
case err := <-sendResultChan:
|
||||
if err != nil {
|
||||
log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err)
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
|
||||
return
|
||||
}
|
||||
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
|
||||
s.metrics.MessagesForwarded.Add(ctx, 1)
|
||||
s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
|
||||
|
||||
case <-dstPeer.Stream.Context().Done():
|
||||
log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey)
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected)))
|
||||
|
||||
case <-time.After(s.sendTimeout):
|
||||
dstPeer.Cancel() // cancel the peer context to trigger deregistration
|
||||
log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey)
|
||||
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout)))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user