diff --git a/signal/loadtest/client.go b/signal/loadtest/client.go new file mode 100644 index 000000000..d4e4ac3d3 --- /dev/null +++ b/signal/loadtest/client.go @@ -0,0 +1,152 @@ +package loadtest + +import ( + "context" + "fmt" + "strings" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + + "github.com/netbirdio/netbird/shared/signal/proto" +) + +// Client represents a signal client for load testing +type Client struct { + id string + conn *grpc.ClientConn + client proto.SignalExchangeClient + stream proto.SignalExchange_ConnectStreamClient + ctx context.Context + cancel context.CancelFunc + msgChannel chan *proto.EncryptedMessage +} + +// NewClient creates a new signal client for load testing +func NewClient(serverURL, peerID string) (*Client, error) { + addr, opts, err := parseServerURL(serverURL) + if err != nil { + return nil, fmt.Errorf("parse server URL: %w", err) + } + + conn, err := grpc.Dial(addr, opts...) + if err != nil { + return nil, fmt.Errorf("dial server: %w", err) + } + + client := proto.NewSignalExchangeClient(conn) + ctx, cancel := context.WithCancel(context.Background()) + + return &Client{ + id: peerID, + conn: conn, + client: client, + ctx: ctx, + cancel: cancel, + msgChannel: make(chan *proto.EncryptedMessage, 10), + }, nil +} + +// Connect establishes a stream connection to the signal server +func (c *Client) Connect() error { + md := metadata.New(map[string]string{proto.HeaderId: c.id}) + ctx := metadata.NewOutgoingContext(c.ctx, md) + + stream, err := c.client.ConnectStream(ctx) + if err != nil { + return fmt.Errorf("connect stream: %w", err) + } + + if _, err := stream.Header(); err != nil { + return fmt.Errorf("receive header: %w", err) + } + + c.stream = stream + + go c.receiveMessages() + + return nil +} + +// SendMessage sends an encrypted message to a remote peer using the Send RPC +func (c *Client) SendMessage(remotePeerID string, body []byte) error { + msg := &proto.EncryptedMessage{ + Key: c.id, + RemoteKey: remotePeerID, + Body: body, + } + + ctx, cancel := context.WithTimeout(c.ctx, 5*time.Second) + defer cancel() + + _, err := c.client.Send(ctx, msg) + if err != nil { + return fmt.Errorf("send message: %w", err) + } + + return nil +} + +// ReceiveMessage waits for and returns the next message +func (c *Client) ReceiveMessage() (*proto.EncryptedMessage, error) { + select { + case msg := <-c.msgChannel: + return msg, nil + case <-c.ctx.Done(): + return nil, c.ctx.Err() + } +} + +// Close closes the client connection +func (c *Client) Close() error { + c.cancel() + close(c.msgChannel) + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +func (c *Client) receiveMessages() { + for { + msg, err := c.stream.Recv() + if err != nil { + return + } + + select { + case c.msgChannel <- msg: + case <-c.ctx.Done(): + return + } + } +} + +func parseServerURL(serverURL string) (string, []grpc.DialOption, error) { + serverURL = strings.TrimSpace(serverURL) + if serverURL == "" { + return "", nil, fmt.Errorf("server URL is empty") + } + + var addr string + var opts []grpc.DialOption + + if strings.HasPrefix(serverURL, "https://") { + addr = strings.TrimPrefix(serverURL, "https://") + return "", nil, fmt.Errorf("TLS support not yet implemented") + } else if strings.HasPrefix(serverURL, "http://") { + addr = strings.TrimPrefix(serverURL, "http://") + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } else { + addr = serverURL + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + + if !strings.Contains(addr, ":") { + return "", nil, fmt.Errorf("server URL must include port") + } + + return addr, opts, nil +} diff --git a/signal/loadtest/loadtest_test.go b/signal/loadtest/loadtest_test.go new file mode 100644 index 000000000..2080259da --- /dev/null +++ b/signal/loadtest/loadtest_test.go @@ -0,0 +1,91 @@ +package loadtest + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/signal/server" +) + +func TestSignalLoadTest_SinglePair(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + grpcServer, serverAddr := startTestSignalServer(t, ctx) + defer grpcServer.Stop() + + sender, err := NewClient(serverAddr, "sender-peer-id") + require.NoError(t, err) + defer sender.Close() + + receiver, err := NewClient(serverAddr, "receiver-peer-id") + require.NoError(t, err) + defer receiver.Close() + + err = sender.Connect() + require.NoError(t, err, "Sender should connect successfully") + + err = receiver.Connect() + require.NoError(t, err, "Receiver should connect successfully") + + time.Sleep(100 * time.Millisecond) + + testMessage := []byte("test message payload") + + t.Log("Sending message from sender to receiver") + err = sender.SendMessage("receiver-peer-id", testMessage) + require.NoError(t, err, "Sender should send message successfully") + + t.Log("Waiting for receiver to receive message") + + receiveDone := make(chan struct{}) + var msg *proto.EncryptedMessage + var receiveErr error + + go func() { + msg, receiveErr = receiver.ReceiveMessage() + close(receiveDone) + }() + + select { + case <-receiveDone: + require.NoError(t, receiveErr, "Receiver should receive message") + require.NotNil(t, msg, "Received message should not be nil") + require.Greater(t, len(msg.Body), 0, "Encrypted message body size should be greater than 0") + require.Equal(t, "sender-peer-id", msg.Key) + require.Equal(t, "receiver-peer-id", msg.RemoteKey) + t.Log("Message received successfully") + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for message") + } +} + +func startTestSignalServer(t *testing.T, ctx context.Context) (*grpc.Server, string) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + grpcServer := grpc.NewServer() + + signalServer, err := server.NewServer(ctx, otel.Meter("test")) + require.NoError(t, err) + + proto.RegisterSignalExchangeServer(grpcServer, signalServer) + + go func() { + if err := grpcServer.Serve(listener); err != nil { + t.Logf("Server stopped: %v", err) + } + }() + + time.Sleep(100 * time.Millisecond) + + return grpcServer, fmt.Sprintf("http://%s", listener.Addr().String()) +}