mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
279 lines
10 KiB
Go
279 lines
10 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/metric"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
gproto "google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
"github.com/netbirdio/signal-dispatcher/dispatcher"
|
|
|
|
"github.com/netbirdio/netbird/shared/signal/proto"
|
|
"github.com/netbirdio/netbird/signal/metrics"
|
|
"github.com/netbirdio/netbird/signal/peer"
|
|
)
|
|
|
|
var ErrPeerNotConnected = errors.New("peer not connected")
|
|
|
|
const (
|
|
labelType = "type"
|
|
labelTypeError = "error"
|
|
labelTypeNotConnected = "not_connected"
|
|
labelTypeNotRegistered = "not_registered"
|
|
labelTypeStream = "stream"
|
|
labelTypeMessage = "message"
|
|
labelTypeTimeout = "timeout"
|
|
labelTypeDisconnected = "disconnected"
|
|
|
|
labelError = "error"
|
|
labelErrorMissingId = "missing_id"
|
|
labelErrorMissingMeta = "missing_meta"
|
|
labelErrorFailedHeader = "failed_header"
|
|
labelErrorFailedRegistration = "failed_registration"
|
|
|
|
labelRegistrationStatus = "status"
|
|
labelRegistrationFound = "found"
|
|
labelRegistrationNotFound = "not_found"
|
|
|
|
sendTimeout = 10 * time.Second
|
|
)
|
|
|
|
var (
|
|
ErrPeerRegisteredAgain = errors.New("peer registered again")
|
|
)
|
|
|
|
// Options carries optional settings for NewServer. The zero value leaves
// SendWithDeliveryCheck enabled.
type Options struct {
	// DisableSendWithDeliveryCheck makes the SendWithDeliveryCheck RPC reject
	// every call with codes.Unimplemented.
	DisableSendWithDeliveryCheck bool
}
|
|
|
|
// Server an instance of a Signal server
|
|
type Server struct {
|
|
metrics *metrics.AppMetrics
|
|
disableSendWithDeliveryCheck bool
|
|
|
|
registry *peer.Registry
|
|
proto.UnimplementedSignalExchangeServer
|
|
dispatcher *dispatcher.Dispatcher
|
|
|
|
successHeader metadata.MD
|
|
|
|
sendTimeout time.Duration
|
|
directSendDisabled bool
|
|
}
|
|
|
|
// NewServer creates a new Signal server
|
|
func NewServer(ctx context.Context, meter metric.Meter, opts *Options) (*Server, error) {
|
|
if opts == nil {
|
|
opts = &Options{}
|
|
}
|
|
|
|
appMetrics, err := metrics.NewAppMetrics(meter)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating app metrics: %v", err)
|
|
}
|
|
|
|
d, err := dispatcher.NewDispatcher(ctx, meter)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating dispatcher: %v", err)
|
|
}
|
|
|
|
sTimeout := sendTimeout
|
|
to := os.Getenv("NB_SIGNAL_SEND_TIMEOUT")
|
|
if parsed, err := time.ParseDuration(to); err == nil && parsed > 0 {
|
|
log.Trace("using custom send timeout ", parsed)
|
|
sTimeout = parsed
|
|
}
|
|
|
|
s := &Server{
|
|
dispatcher: d,
|
|
registry: peer.NewRegistry(appMetrics),
|
|
metrics: appMetrics,
|
|
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
|
|
sendTimeout: sTimeout,
|
|
disableSendWithDeliveryCheck: opts.DisableSendWithDeliveryCheck,
|
|
}
|
|
|
|
if directSendDisabled := os.Getenv("NB_SIGNAL_DIRECT_SEND_DISABLED"); directSendDisabled == "true" {
|
|
s.directSendDisabled = true
|
|
log.Warn("direct send to connected peers is disabled")
|
|
}
|
|
|
|
if opts.DisableSendWithDeliveryCheck {
|
|
log.Warn("SendWithDeliveryCheck method is disabled")
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
// Send forwards a message to the signal peer
|
|
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
|
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
|
|
|
|
if _, found := s.registry.Get(msg.RemoteKey); found && !s.directSendDisabled {
|
|
_ = s.forwardMessageToPeer(ctx, msg)
|
|
return &proto.EncryptedMessage{}, nil
|
|
}
|
|
|
|
if _, err := s.dispatcher.SendMessage(ctx, msg, false); err != nil {
|
|
log.Errorf("error sending message via dispatcher: %v", err)
|
|
}
|
|
return &proto.EncryptedMessage{}, nil
|
|
}
|
|
|
|
// SendWithDeliveryCheck forwards a message to the signal peer with error handling
|
|
// When the remote peer is not connected it returns codes.NotFound error, otherwise it returns other types of errors
|
|
// that can be retried. In case codes.NotFound is returned the caller should not retry sending the message. The remote
|
|
// peer should send a new offer to re-establish the connection when it comes back online.
|
|
// Todo: double check the thread safe registry management. When both peer come online at the same time then both peers
|
|
// might not be registered yet when the first message is sent.
|
|
func (s *Server) SendWithDeliveryCheck(ctx context.Context, msg *proto.EncryptedMessage) (*emptypb.Empty, error) {
|
|
if s.disableSendWithDeliveryCheck {
|
|
log.Tracef("SendWithDeliveryCheck is disabled")
|
|
return nil, status.Errorf(codes.Unimplemented, "SendWithDeliveryCheck method is disabled")
|
|
}
|
|
|
|
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
|
|
if _, found := s.registry.Get(msg.RemoteKey); found && !s.directSendDisabled {
|
|
if err := s.forwardMessageToPeer(ctx, msg); err != nil {
|
|
if errors.Is(err, ErrPeerNotConnected) {
|
|
log.Tracef("remote peer [%s] not connected", msg.RemoteKey)
|
|
return nil, status.Errorf(codes.NotFound, "remote peer not connected")
|
|
}
|
|
log.Errorf("error sending message with delivery check to peer [%s]: %v", msg.RemoteKey, err)
|
|
return nil, status.Errorf(codes.Internal, "error forwarding message to peer: %v", err)
|
|
}
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
if _, err := s.dispatcher.SendMessage(ctx, msg, true); err != nil {
|
|
if errors.Is(err, dispatcher.ErrPeerNotConnected) {
|
|
log.Tracef("remote peer [%s] doesn't have a listener", msg.RemoteKey)
|
|
return nil, status.Errorf(codes.NotFound, "remote peer not connected")
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
return &emptypb.Empty{}, nil
|
|
}
|
|
|
|
// ConnectStream connects to the exchange stream
|
|
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
p, err := s.RegisterPeer(stream, cancel)
|
|
if err != nil {
|
|
log.Errorf("error registering peer: %v", err)
|
|
return err
|
|
}
|
|
|
|
defer s.DeregisterPeer(p)
|
|
|
|
// needed to confirm that the peer has been registered so that the client can proceed
|
|
err = stream.SendHeader(s.successHeader)
|
|
if err != nil {
|
|
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedHeader)))
|
|
log.Errorf("error sending registration header to peer [%s] [streamID %d] : %v", p.Id, p.StreamID, err)
|
|
return err
|
|
}
|
|
|
|
log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
|
|
|
|
select {
|
|
case <-stream.Context().Done():
|
|
log.Debugf("peer stream closing [%s] [streamID %d] ", p.Id, p.StreamID)
|
|
return nil
|
|
case <-ctx.Done():
|
|
return ErrPeerRegisteredAgain
|
|
}
|
|
}
|
|
|
|
func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) (*peer.Peer, error) {
|
|
log.Debugf("registering new peer")
|
|
id := metadata.ValueFromIncomingContext(stream.Context(), proto.HeaderId)
|
|
if id == nil {
|
|
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId)))
|
|
return nil, status.Errorf(codes.FailedPrecondition, "missing connection header: %s", proto.HeaderId)
|
|
}
|
|
|
|
p := peer.NewPeer(id[0], stream, cancel)
|
|
if err := s.registry.Register(p); err != nil {
|
|
return nil, fmt.Errorf("error adding peer to registry peer: %w", err)
|
|
}
|
|
err := s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
|
|
if err != nil {
|
|
s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorFailedRegistration)))
|
|
log.Errorf("error while registering message listener for peer [%s] %v", p.Id, err)
|
|
return nil, status.Errorf(codes.Internal, "error while registering message listener")
|
|
}
|
|
return p, nil
|
|
}
|
|
|
|
func (s *Server) DeregisterPeer(p *peer.Peer) {
|
|
log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
|
|
s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
|
|
s.registry.Deregister(p)
|
|
}
|
|
|
|
func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) error {
|
|
log.Tracef("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
|
|
getRegistrationStart := time.Now()
|
|
|
|
// lookup the target peer where the message is going to
|
|
dstPeer, found := s.registry.Get(msg.RemoteKey)
|
|
|
|
if !found {
|
|
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
|
|
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
|
|
log.Tracef("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
|
|
// todo respond to the sender?
|
|
return ErrPeerNotConnected
|
|
}
|
|
|
|
s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
|
|
start := time.Now()
|
|
|
|
sendResultChan := make(chan error, 1)
|
|
go func() {
|
|
select {
|
|
case sendResultChan <- dstPeer.Stream.Send(msg):
|
|
return
|
|
case <-dstPeer.Stream.Context().Done():
|
|
return
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case err := <-sendResultChan:
|
|
if err != nil {
|
|
log.Tracef("error while forwarding message from peer [%s] to peer [%s]: %v", msg.Key, msg.RemoteKey, err)
|
|
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
|
|
return fmt.Errorf("error sending message to peer: %v", err)
|
|
}
|
|
s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
|
|
s.metrics.MessagesForwarded.Add(ctx, 1)
|
|
s.metrics.MessageSize.Record(ctx, int64(gproto.Size(msg)), metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
|
|
|
|
case <-dstPeer.Stream.Context().Done():
|
|
log.Tracef("failed to forward message from peer [%s] to peer [%s]: destination peer disconnected", msg.Key, msg.RemoteKey)
|
|
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeDisconnected)))
|
|
return fmt.Errorf("destination peer disconnected")
|
|
case <-time.After(s.sendTimeout):
|
|
dstPeer.Cancel() // cancel the peer context to trigger deregistration
|
|
log.Tracef("failed to forward message from peer [%s] to peer [%s]: send timeout", msg.Key, msg.RemoteKey)
|
|
s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeTimeout)))
|
|
return fmt.Errorf("sending message to peer timeout")
|
|
}
|
|
|
|
return nil
|
|
}
|