Add dialWebSocket method to WASM client

This commit is contained in:
Viktor Liu
2026-04-11 20:06:23 +02:00
parent 5259e5df51
commit 009a1edcaa
4 changed files with 298 additions and 0 deletions

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/client/wasm/internal/ssh"
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -516,6 +517,7 @@ func createClientObject(client *netbird.Client) js.Value {
obj["createSSHConnection"] = createSSHMethod(client) obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["dialWebSocket"] = createDialWebSocketMethod(client)
obj["status"] = createStatusMethod(client) obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client) obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client) obj["statusDetail"] = createStatusDetailMethod(client)
@@ -525,6 +527,31 @@ func createClientObject(client *netbird.Client) js.Value {
return js.ValueOf(obj) return js.ValueOf(obj)
} }
const dialWebSocketTimeout = 30 * time.Second
func createDialWebSocketMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 || args[0].Type() != js.TypeString {
return js.ValueOf("error: dialWebSocket requires a URL string argument")
}
url := args[0].String()
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), dialWebSocketTimeout)
defer cancel()
conn, err := nbwebsocket.Dial(ctx, client, url)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
return
}
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
})
})
}
// netBirdClientConstructor acts as a JavaScript constructor function // netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(_ js.Value, args []js.Value) any { func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {

View File

@@ -0,0 +1,261 @@
//go:build js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
log "github.com/sirupsen/logrus"
)
type closeError struct {
code uint16
reason string
}
func (e *closeError) Error() string {
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
}
// Conn wraps a WebSocket connection over a NetBird TCP connection.
type Conn struct {
conn net.Conn
mu sync.Mutex
closed chan struct{}
closeOnce sync.Once
closeErr error
}
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
func Dial(ctx context.Context, client *netbird.Client, rawURL string) (*Conn, error) {
d := ws.Dialer{
NetDial: client.Dial,
}
conn, br, _, err := d.Dial(ctx, rawURL)
if err != nil {
return nil, fmt.Errorf("websocket dial: %w", err)
}
if br != nil {
ws.PutReader(br)
}
return &Conn{
conn: conn,
closed: make(chan struct{}),
}, nil
}
// ReadMessage reads the next WebSocket message, handling control frames automatically.
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
for {
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
if err != nil {
return 0, nil, err
}
for _, msg := range msgs {
if msg.OpCode.IsControl() {
if err := c.handleControl(msg); err != nil {
return 0, nil, err
}
continue
}
return msg.OpCode, msg.Payload, nil
}
}
}
func (c *Conn) handleControl(msg wsutil.Message) error {
switch msg.OpCode {
case ws.OpPing:
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
case ws.OpClose:
code, reason := parseClosePayload(msg.Payload)
return &closeError{code: code, reason: reason}
default:
return nil
}
}
// WriteText sends a text WebSocket message.
func (c *Conn) WriteText(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
}
// WriteBinary sends a binary WebSocket message.
func (c *Conn) WriteBinary(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
}
// Close sends a close frame and closes the underlying connection.
func (c *Conn) Close() error {
var first bool
c.closeOnce.Do(func() {
first = true
close(c.closed)
c.mu.Lock()
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose,
ws.NewCloseFrameBody(ws.StatusNormalClosure, ""),
)
c.mu.Unlock()
c.closeErr = c.conn.Close()
})
if !first {
return net.ErrClosed
}
return c.closeErr
}
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
// It exposes: send(string|Uint8Array), close(), and callback properties
// onmessage, onclose, onerror.
//
// Callback properties may be set from the JS thread while the read loop
// goroutine reads them. In WASM this is safe because Go and JS share a
// single thread, but the design would need synchronization on
// multi-threaded runtimes.
func NewJSInterface(conn *Conn) js.Value {
obj := js.Global().Get("Object").Call("create", js.Null())
obj.Set("send", js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("send requires a data argument")
}
data := args[0]
switch data.Type() {
case js.TypeString:
if err := conn.WriteText([]byte(data.String())); err != nil {
log.Errorf("failed to send websocket text: %v", err)
return js.ValueOf(false)
}
default:
buf, err := jsToBytes(data)
if err != nil {
return js.ValueOf(err.Error())
}
if err := conn.WriteBinary(buf); err != nil {
log.Errorf("failed to send websocket binary: %v", err)
return js.ValueOf(false)
}
}
return js.ValueOf(true)
}))
obj.Set("close", js.FuncOf(func(_ js.Value, _ []js.Value) any {
if err := conn.Close(); err != nil {
log.Debugf("failed to close websocket: %v", err)
}
return js.Undefined()
}))
go readLoop(conn, obj)
return obj
}
func jsToBytes(data js.Value) ([]byte, error) {
var uint8Array js.Value
switch {
case data.InstanceOf(js.Global().Get("Uint8Array")):
uint8Array = data
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
uint8Array = js.Global().Get("Uint8Array").New(data)
default:
return nil, fmt.Errorf("send: unsupported data type, use string or Uint8Array")
}
buf := make([]byte, uint8Array.Get("length").Int())
js.CopyBytesToGo(buf, uint8Array)
return buf, nil
}
func readLoop(conn *Conn, obj js.Value) {
var closeCode uint16
var closeReason string
var gotCloseFrame bool
defer func() {
onclose := obj.Get("onclose")
if !onclose.Truthy() {
return
}
if gotCloseFrame {
onclose.Invoke(js.ValueOf(int(closeCode)), js.ValueOf(closeReason))
return
}
onclose.Invoke()
}()
for {
select {
case <-conn.closed:
return
default:
}
op, payload, err := conn.ReadMessage()
if err != nil {
var ce *closeError
if errors.As(err, &ce) {
gotCloseFrame = true
closeCode = ce.code
closeReason = ce.reason
// Respond to server close per RFC 6455.
if err := conn.Close(); err != nil {
log.Debugf("failed to close websocket after server close frame: %v", err)
}
return
}
if err != io.EOF {
if onerror := obj.Get("onerror"); onerror.Truthy() {
onerror.Invoke(js.ValueOf(err.Error()))
}
}
return
}
onmessage := obj.Get("onmessage")
if !onmessage.Truthy() {
continue
}
switch op {
case ws.OpText:
onmessage.Invoke(js.ValueOf(string(payload)))
case ws.OpBinary:
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
js.CopyBytesToJS(uint8Array, payload)
onmessage.Invoke(uint8Array)
}
}
}
func parseClosePayload(payload []byte) (uint16, string) {
if len(payload) < 2 {
return 1005, "" // RFC 6455: No Status Rcvd
}
code := binary.BigEndian.Uint16(payload[:2])
return code, string(payload[2:])
}

3
go.mod
View File

@@ -190,6 +190,9 @@ require (
github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-text/render v0.2.0 // indirect github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect github.com/go-text/typesetting v0.2.1 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gobwas/ws v1.3.2 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/btree v1.1.2 // indirect github.com/google/btree v1.1.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect github.com/google/s2a-go v0.1.9 // indirect

7
go.sum
View File

@@ -203,6 +203,12 @@ github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg
github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.3.2 h1:zlnbNHxumkRvfPWgfXu8RBwyNR1x8wh9cf5PTOCqs9Q=
github.com/gobwas/ws v1.3.2/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
@@ -738,6 +744,7 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=