Files
netbird/signal/server/signal.go
Pascal Fischer 6ce67f383e add logging
2025-10-09 10:14:44 +02:00

279 lines
10 KiB
Go

package server
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
gproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
)
// ErrPeerNotConnected is returned by forwardMessageToPeer when the destination
// peer of a message is not present in the server's peer registry.
var ErrPeerNotConnected = errors.New("peer not connected")
// Metric attribute key "type" and the values recorded under it.
const (
	labelType = "type"

	labelTypeError         = "error"
	labelTypeNotConnected  = "not_connected"
	labelTypeNotRegistered = "not_registered"
	labelTypeStream        = "stream"
	labelTypeMessage       = "message"
	labelTypeTimeout       = "timeout"
	labelTypeDisconnected  = "disconnected"
)

// Metric attribute key "error" and the values recorded under it when a peer
// registration fails.
const (
	labelError = "error"

	labelErrorMissingId          = "missing_id"
	labelErrorMissingMeta        = "missing_meta"
	labelErrorFailedHeader       = "failed_header"
	labelErrorFailedRegistration = "failed_registration"
)

// Metric attribute key "status" and the values describing the outcome of a
// registry lookup.
const (
	labelRegistrationStatus = "status"

	labelRegistrationFound    = "found"
	labelRegistrationNotFound = "not_found"
)

// sendTimeout is the default upper bound for forwarding a single message to a
// peer stream; NewServer uses it unless NB_SIGNAL_SEND_TIMEOUT overrides it.
const sendTimeout = 10 * time.Second
// ErrPeerRegisteredAgain is returned by ConnectStream when the context tied to
// the peer's registration is cancelled while the stream is still open.
var ErrPeerRegisteredAgain = errors.New("peer registered again")
// Options holds optional settings for NewServer. A nil *Options is treated by
// NewServer as the zero value.
type Options struct {
// DisableSendWithDeliveryCheck, when true, makes the SendWithDeliveryCheck
// method reject every call with codes.Unimplemented.
DisableSendWithDeliveryCheck bool
}
// Server is an instance of a Signal server.
type Server struct {
// metrics receives the counters and measurements recorded by the handlers.
metrics *metrics.AppMetrics
// disableSendWithDeliveryCheck makes SendWithDeliveryCheck return codes.Unimplemented.
disableSendWithDeliveryCheck bool
// registry tracks the peers registered through ConnectStream.
registry *peer.Registry
// NOTE(review): presumably embedded to satisfy the generated service interface — confirm.
proto.UnimplementedSignalExchangeServer
// dispatcher carries messages whose destination is not found in registry,
// or all messages when directSendDisabled is set.
dispatcher *dispatcher.Dispatcher
// successHeader is sent on the stream to confirm a successful registration.
successHeader metadata.MD
// sendTimeout bounds how long forwarding a single message to a peer stream may take.
sendTimeout time.Duration
// directSendDisabled is set from NB_SIGNAL_DIRECT_SEND_DISABLED; when true
// Send and SendWithDeliveryCheck never forward directly to a registered peer.
directSendDisabled bool
}
// NewServer creates a new Signal server.
//
// meter is used to create the application metrics and the dispatcher. opts may
// be nil, in which case default options are used.
//
// Two environment variables are read:
//   - NB_SIGNAL_SEND_TIMEOUT: a positive time.ParseDuration string overriding
//     the default send timeout. Invalid or non-positive values are ignored
//     with a warning.
//   - NB_SIGNAL_DIRECT_SEND_DISABLED: when exactly "true", messages are never
//     forwarded directly to a registered peer.
//
// An error is returned when the metrics or the dispatcher cannot be created;
// the underlying cause is available through errors.Is / errors.As.
func NewServer(ctx context.Context, meter metric.Meter, opts *Options) (*Server, error) {
	if opts == nil {
		opts = &Options{}
	}

	appMetrics, err := metrics.NewAppMetrics(meter)
	if err != nil {
		return nil, fmt.Errorf("creating app metrics: %w", err)
	}

	d, err := dispatcher.NewDispatcher(ctx, meter)
	if err != nil {
		return nil, fmt.Errorf("creating dispatcher: %w", err)
	}

	sTimeout := sendTimeout
	if raw := os.Getenv("NB_SIGNAL_SEND_TIMEOUT"); raw != "" {
		parsed, parseErr := time.ParseDuration(raw)
		switch {
		case parseErr != nil:
			// Keep the default, but make the misconfiguration visible.
			log.Warnf("ignoring invalid NB_SIGNAL_SEND_TIMEOUT %q: %v", raw, parseErr)
		case parsed <= 0:
			log.Warnf("ignoring non-positive NB_SIGNAL_SEND_TIMEOUT %q", raw)
		default:
			log.Trace("using custom send timeout ", parsed)
			sTimeout = parsed
		}
	}

	s := &Server{
		dispatcher:                   d,
		registry:                     peer.NewRegistry(appMetrics),
		metrics:                      appMetrics,
		successHeader:                metadata.Pairs(proto.HeaderRegistered, "1"),
		sendTimeout:                  sTimeout,
		disableSendWithDeliveryCheck: opts.DisableSendWithDeliveryCheck,
	}

	if os.Getenv("NB_SIGNAL_DIRECT_SEND_DISABLED") == "true" {
		s.directSendDisabled = true
		log.Warn("direct send to connected peers is disabled")
	}

	if opts.DisableSendWithDeliveryCheck {
		log.Warn("SendWithDeliveryCheck method is disabled")
	}

	return s, nil
}
// Send forwards a message to the signal peer on a best-effort basis.
//
// If the remote peer is found in the registry and direct send is enabled, the
// message is forwarded straight to that peer; otherwise it is handed to the
// dispatcher. Delivery failures are never reported to the caller: the method
// always returns an empty message and a nil error.
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
	log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

	_, registered := s.registry.Get(msg.RemoteKey)
	if registered && !s.directSendDisabled {
		// The error is dropped on purpose: Send is fire-and-forget, and
		// forwardMessageToPeer already logs and counts its own failures.
		_ = s.forwardMessageToPeer(ctx, msg)
	} else {
		_, err := s.dispatcher.SendMessage(ctx, msg, false)
		if err != nil {
			log.Errorf("error sending message via dispatcher: %v", err)
		}
	}

	return &proto.EncryptedMessage{}, nil
}
// SendWithDeliveryCheck forwards a message to the signal peer and reports the outcome to the caller.
//
// Returned status codes:
//   - codes.Unimplemented when the method has been disabled through Options.
//   - codes.NotFound when the remote peer is not connected. The caller should not retry in this
//     case; the remote peer should send a new offer to re-establish the connection when it comes
//     back online.
//   - codes.Internal when forwarding directly to a registered peer fails for any other reason.
//
// Errors returned by the dispatcher other than dispatcher.ErrPeerNotConnected are passed through
// unchanged. Errors other than codes.NotFound can be retried.
//
// Todo: double check the thread safe registry management. When both peers come online at the same
// time then both peers might not be registered yet when the first message is sent.
func (s *Server) SendWithDeliveryCheck(ctx context.Context, msg *proto.EncryptedMessage) (*emptypb.Empty, error) {
	if s.disableSendWithDeliveryCheck {
		log.Tracef("SendWithDeliveryCheck is disabled")
		return nil, status.Errorf(codes.Unimplemented, "SendWithDeliveryCheck method is disabled")
	}

	log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)

	_, registered := s.registry.Get(msg.RemoteKey)
	if registered && !s.directSendDisabled {
		// Direct path: the destination peer is registered on this server.
		fwdErr := s.forwardMessageToPeer(ctx, msg)
		switch {
		case fwdErr == nil:
			return &emptypb.Empty{}, nil
		case errors.Is(fwdErr, ErrPeerNotConnected):
			log.Tracef("remote peer [%s] not connected", msg.RemoteKey)
			return nil, status.Errorf(codes.NotFound, "remote peer not connected")
		default:
			log.Errorf("error sending message with delivery check to peer [%s]: %v", msg.RemoteKey, fwdErr)
			return nil, status.Errorf(codes.Internal, "error forwarding message to peer: %v", fwdErr)
		}
	}

	// Dispatcher path: the peer is not registered here, or direct send is disabled.
	_, dispErr := s.dispatcher.SendMessage(ctx, msg, true)
	switch {
	case dispErr == nil:
		return &emptypb.Empty{}, nil
	case errors.Is(dispErr, dispatcher.ErrPeerNotConnected):
		log.Tracef("remote peer [%s] doesn't have a listener", msg.RemoteKey)
		return nil, status.Errorf(codes.NotFound, "remote peer not connected")
	default:
		return nil, dispErr
	}
}
// ConnectStream connects the calling peer to the exchange stream and keeps the
// stream open until the peer goes away.
//
// The peer is registered first; a registration error is logged and returned.
// The success header is then sent so that the client knows it has been
// registered and can proceed. After that the call blocks and returns:
//   - nil once the stream's context is done, or
//   - ErrPeerRegisteredAgain once the context handed to the peer as its cancel
//     func is cancelled. In this file that happens on a send timeout in
//     forwardMessageToPeer; presumably the registry also triggers it when the
//     same peer registers again — confirm against the peer package.
//
// The peer is deregistered when the method returns after a successful
// registration.
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
	ctx, cancel := context.WithCancel(context.Background())
	// Release the context on every return path, including a failed
	// registration. Cancelling more than once is a no-op, so this is safe even
	// if the peer's cancel func has already been invoked.
	defer cancel()

	p, err := s.RegisterPeer(stream, cancel)
	if err != nil {
		log.Errorf("error registering peer: %v", err)
		return err
	}

	defer s.DeregisterPeer(p)

	// needed to confirm that the peer has been registered so that the client can proceed
	err = stream.SendHeader(s.successHeader)
	if err != nil {
		s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader)))
		log.Errorf("error sending registration header to peer [%s] [streamID %d] : %v", p.Id, p.StreamID, err)
		return err
	}

	log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)

	select {
	case <-stream.Context().Done():
		log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
		return nil
	case <-ctx.Done():
		return ErrPeerRegisteredAgain
	}
}
// RegisterPeer creates a peer for the given stream, adds it to the registry
// and subscribes it to messages coming from the dispatcher.
//
// The peer ID is taken from the first value of the proto.HeaderId entry in the
// stream's incoming metadata. cancel is stored on the peer.
//
// Errors:
//   - a codes.FailedPrecondition status when the ID header is missing or has
//     no values;
//   - a wrapped registry error when the peer cannot be added to the registry;
//   - a codes.Internal status when the dispatcher listener cannot be set up.
//     In that case the peer is removed from the registry again, because the
//     caller does not deregister peers whose registration failed.
func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) {
	log.Debugf("registering new peer")

	id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId)
	// len() covers both a missing header (nil) and a header that is present
	// without any value, so indexing id[0] below cannot panic.
	if len(id) == 0 {
		s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
		return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
	}

	p := peer.NewPeer(id[0], stream, cancel)
	if err := s.registry.Register(p); err != nil {
		return nil, fmt.Errorf("error adding peer to registry peer: %w", err)
	}

	if err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer); err != nil {
		// Roll back the registration so a peer without a working listener does
		// not linger in the registry.
		s.registry.Deregister(p)
		s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration)))
		log.Errorf("error while registering message listener for peer [%s] %v", p.Id, err)
		return nil, status.Errorf(codes.Internal, "error while registering message listener")
	}

	return p, nil
}
// DeregisterPeer records how long the peer has been connected, in whole
// seconds since p.RegisteredAt, and then removes the peer from the registry.
func (s *Server) DeregisterPeer(p *peer.Peer) {
	log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)

	streamCtx := p.Stream.Context()
	connectedFor := time.Since(p.RegisteredAt)
	s.metrics.PeerConnectionDuration.Record(streamCtx, int64(connectedFor.Seconds()))

	s.registry.Deregister(p)
}
// forwardMessageToPeer delivers msg to the peer registered under msg.RemoteKey.
//
// It returns ErrPeerNotConnected when the registry has no entry for the
// destination. Otherwise the message is written to the destination stream
// from a separate goroutine so that the wait can be bounded. A non-nil error
// is returned when:
//   - the stream write fails,
//   - the destination stream's context is done before the write completes, or
//   - s.sendTimeout elapses, in which case the destination peer's Cancel func
//     is invoked as well.
//
// Every outcome is recorded in the server metrics. ctx is only used for
// recording metrics; it does not bound the send.
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) error {
	log.Tracef("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
	getRegistrationStart := time.Now()

	// lookup the target peer where the message is going to
	dstPeer, found := s.registry.Get(msg.RemoteKey)
	if !found {
		s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
		s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
		log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
		// todo respond to the sender?
		return ErrPeerNotConnected
	}

	s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
	start := time.Now()

	// The channel is buffered so the sending goroutine can always hand over its
	// result and exit, even after this function has returned.
	// NOTE(review): nothing visible here serializes concurrent sends to the
	// same destination stream — confirm the stream/peer type guards against that.
	sendResultChan := make(chan error, 1)
	go func() {
		select {
		case sendResultChan <- dstPeer.Stream.Send(msg):
			return
		case <-dstPeer.Stream.Context().Done():
			return
		}
	}()

	// An explicit timer that is stopped on return, instead of time.After, so
	// that a fast send does not leave a pending timer behind for every message.
	timeout := time.NewTimer(s.sendTimeout)
	defer timeout.Stop()

	select {
	case err := <-sendResultChan:
		if err != nil {
			log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err)
			s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
			return fmt.Errorf("error sending message to peer: %v", err)
		}
		s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
		s.metrics.MessagesForwarded.Add(ctx, 1)
		s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))

	case <-dstPeer.Stream.Context().Done():
		log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey)
		s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected)))
		return errors.New("destination peer disconnected")

	case <-timeout.C:
		dstPeer.Cancel() // cancel the peer context to trigger deregistration
		log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey)
		s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout)))
		return errors.New("sending message to peer timeout")
	}

	return nil
}