diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index 58be304cc..7153d986a 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -2,6 +2,7 @@ package peer import ( "errors" + "sync/atomic" "github.com/pion/ice/v4" log "github.com/sirupsen/logrus" @@ -11,11 +12,16 @@ import ( sProto "github.com/netbirdio/netbird/shared/signal/proto" ) -var ErrPeerNotAvailable = signal.ErrPeerNotAvailable +var ( + ErrPeerNotAvailable = signal.ErrPeerNotAvailable + ErrSignalNotSupportDeliveryCheck = errors.New("the signal client does not support SendWithDeliveryCheck") +) type Signaler struct { signal signal.Client wgPrivateKey wgtypes.Key + + deliveryCheckNotSupported atomic.Bool } func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler { @@ -71,13 +77,21 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, return err } - if err = s.signal.SendWithDeliveryCheck(msg); err != nil { - if errors.Is(err, signal.ErrPeerNotAvailable) { - return ErrPeerNotAvailable - } - return err + if s.deliveryCheckNotSupported.Load() { + return s.signal.Send(msg) } + if err = s.signal.SendWithDeliveryCheck(msg); err != nil { + switch { + case errors.Is(err, signal.ErrPeerNotAvailable): + return ErrPeerNotAvailable + case errors.Is(err, signal.ErrUnimplementedMethod): + s.deliveryCheckNotSupported.Store(true) + return ErrSignalNotSupportDeliveryCheck + default: + return err + } + } return nil } diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index eeb7f9c0c..d8ca4bb55 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -24,7 +24,8 @@ import ( ) var ( - ErrPeerNotAvailable = errors.New("peer not available") + ErrPeerNotAvailable = errors.New("peer not available") + ErrUnimplementedMethod = errors.New("the signal client does not support SendWithDeliveryCheck") ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -420,6 +421,8 @@ func (c *GrpcClient) SendWithDeliveryCheck(msg *proto.Message) error { switch st.Code() { case codes.NotFound: return ErrPeerNotAvailable + case codes.Unimplemented: + return ErrUnimplementedMethod default: return fmt.Errorf("grpc error %s: %w", st.Code(), err) } diff --git a/signal/cmd/env.go b/signal/cmd/env.go index 3c15ebe1f..b04cb6188 100644 --- a/signal/cmd/env.go +++ b/signal/cmd/env.go @@ -2,6 +2,7 @@ package cmd import ( "os" + "strconv" "strings" log "github.com/sirupsen/logrus" @@ -9,6 +10,19 @@ import ( "github.com/spf13/pflag" ) +func EnvDisableSendWithDeliveryCheck() bool { + envVar := "NB_DISABLE_SEND_WITH_DELIVERY_CHECK" + value, present := os.LookupEnv(envVar) + if !present { + return false + } + + if parsed, err := strconv.ParseBool(value); err == nil { + return parsed + } + return false +} + // setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_ func setFlagsFromEnvVars(cmd *cobra.Command) { flags := cmd.PersistentFlags() diff --git a/signal/cmd/run.go b/signal/cmd/run.go index dea90ddc3..c528fbb8b 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -109,7 +109,11 @@ var ( } }() - srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) + optsSignal := &server.Options{ + DisableSendWithDeliveryCheck: EnvDisableSendWithDeliveryCheck(), + } + + srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, optsSignal) if err != nil { return fmt.Errorf("creating signal server: %v", err) } diff --git a/signal/server/signal.go b/signal/server/signal.go index 47beda29b..d884ab157 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -50,12 +50,19 @@ var ( ErrPeerRegisteredAgain = errors.New("peer registered again") ) +type Options struct { + // Disable SendWithDeliveryCheck method + DisableSendWithDeliveryCheck bool +} + // Server an instance of a Signal server type Server struct { + metrics *metrics.AppMetrics + disableSendWithDeliveryCheck bool + registry *peer.Registry proto.UnimplementedSignalExchangeServer dispatcher *dispatcher.Dispatcher - metrics *metrics.AppMetrics successHeader metadata.MD @@ -63,7 +70,11 @@ type Server struct { } // NewServer creates a new Signal server -func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { +func NewServer(ctx context.Context, meter metric.Meter, opts *Options) (*Server, error) { + if opts == nil { + opts = &Options{} + } + appMetrics, err := metrics.NewAppMetrics(meter) if err != nil { return nil, fmt.Errorf("creating app metrics: %v", err) @@ -82,11 +93,12 @@ func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { } s := &Server{ - dispatcher: d, - registry: peer.NewRegistry(appMetrics), - metrics: appMetrics, - successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), - sendTimeout: sTimeout, + dispatcher: d, + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, + successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), + sendTimeout: sTimeout, + disableSendWithDeliveryCheck: opts.DisableSendWithDeliveryCheck, } return s, nil @@ -111,8 +123,12 @@ func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto. // Todo: double check the thread safe registry management. When both peer come online at the same time then both peers // might not be registered yet when the first message is sent. func (s *Server) SendWithDeliveryCheck(ctx context.Context, msg *proto.EncryptedMessage) (*emptypb.Empty, error) { - log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) + if s.disableSendWithDeliveryCheck { + log.Tracef("SendWithDeliveryCheck is disabled") + return nil, status.Errorf(codes.Unimplemented, "SendWithDeliveryCheck method is disabled") + } + log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) if _, found := s.registry.Get(msg.RemoteKey); found { // todo error handling here too err := s.forwardMessageToPeer(ctx, msg)