diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index d8e50ab6d..984f111ad 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -17,6 +17,7 @@ import ( "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" + nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket" "github.com/netbirdio/netbird/util" ) @@ -516,6 +517,7 @@ func createClientObject(client *netbird.Client) js.Value { obj["createSSHConnection"] = createSSHMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client) + obj["dialWebSocket"] = createDialWebSocketMethod(client) obj["status"] = createStatusMethod(client) obj["statusSummary"] = createStatusSummaryMethod(client) obj["statusDetail"] = createStatusDetailMethod(client) @@ -525,6 +527,31 @@ func createClientObject(client *netbird.Client) js.Value { return js.ValueOf(obj) } +const dialWebSocketTimeout = 30 * time.Second + +func createDialWebSocketMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 || args[0].Type() != js.TypeString { + return js.ValueOf("error: dialWebSocket requires a URL string argument") + } + + url := args[0].String() + + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), dialWebSocketTimeout) + defer cancel() + + conn, err := nbwebsocket.Dial(ctx, client, url) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err))) + return + } + + resolve.Invoke(nbwebsocket.NewJSInterface(conn)) + }) + }) +} + // netBirdClientConstructor acts as a JavaScript constructor function func netBirdClientConstructor(_ js.Value, args []js.Value) any { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { diff --git a/client/wasm/internal/websocket/websocket.go b/client/wasm/internal/websocket/websocket.go new file mode 100644 index 000000000..53b66d9d8 --- /dev/null +++ b/client/wasm/internal/websocket/websocket.go @@ -0,0 +1,261 @@ +//go:build js + +package websocket + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "sync" + "syscall/js" + + netbird "github.com/netbirdio/netbird/client/embed" + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" + log "github.com/sirupsen/logrus" +) + +type closeError struct { + code uint16 + reason string +} + +func (e *closeError) Error() string { + return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason) +} + +// Conn wraps a WebSocket connection over a NetBird TCP connection. +type Conn struct { + conn net.Conn + mu sync.Mutex + closed chan struct{} + closeOnce sync.Once + closeErr error +} + +// Dial establishes a WebSocket connection to the given URL through the NetBird network. +func Dial(ctx context.Context, client *netbird.Client, rawURL string) (*Conn, error) { + d := ws.Dialer{ + NetDial: client.Dial, + } + + conn, br, _, err := d.Dial(ctx, rawURL) + if err != nil { + return nil, fmt.Errorf("websocket dial: %w", err) + } + + if br != nil { + ws.PutReader(br) + } + + return &Conn{ + conn: conn, + closed: make(chan struct{}), + }, nil +} + +// ReadMessage reads the next WebSocket message, handling control frames automatically. +func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) { + for { + msgs, err := wsutil.ReadServerMessage(c.conn, nil) + if err != nil { + return 0, nil, err + } + + for _, msg := range msgs { + if msg.OpCode.IsControl() { + if err := c.handleControl(msg); err != nil { + return 0, nil, err + } + continue + } + return msg.OpCode, msg.Payload, nil + } + } +} + +func (c *Conn) handleControl(msg wsutil.Message) error { + switch msg.OpCode { + case ws.OpPing: + c.mu.Lock() + defer c.mu.Unlock() + return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload) + case ws.OpClose: + code, reason := parseClosePayload(msg.Payload) + return &closeError{code: code, reason: reason} + default: + return nil + } +} + +// WriteText sends a text WebSocket message. +func (c *Conn) WriteText(data []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + return wsutil.WriteClientMessage(c.conn, ws.OpText, data) +} + +// WriteBinary sends a binary WebSocket message. +func (c *Conn) WriteBinary(data []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data) +} + +// Close sends a close frame and closes the underlying connection. +func (c *Conn) Close() error { + var first bool + c.closeOnce.Do(func() { + first = true + close(c.closed) + + c.mu.Lock() + _ = wsutil.WriteClientMessage(c.conn, ws.OpClose, + ws.NewCloseFrameBody(ws.StatusNormalClosure, ""), + ) + c.mu.Unlock() + + c.closeErr = c.conn.Close() + }) + + if !first { + return net.ErrClosed + } + return c.closeErr +} + +// NewJSInterface creates a JavaScript object wrapping the WebSocket connection. +// It exposes: send(string|Uint8Array), close(), and callback properties +// onmessage, onclose, onerror. +// +// Callback properties may be set from the JS thread while the read loop +// goroutine reads them. In WASM this is safe because Go and JS share a +// single thread, but the design would need synchronization on +// multi-threaded runtimes. +func NewJSInterface(conn *Conn) js.Value { + obj := js.Global().Get("Object").Call("create", js.Null()) + + obj.Set("send", js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("send requires a data argument") + } + + data := args[0] + switch data.Type() { + case js.TypeString: + if err := conn.WriteText([]byte(data.String())); err != nil { + log.Errorf("failed to send websocket text: %v", err) + return js.ValueOf(false) + } + default: + buf, err := jsToBytes(data) + if err != nil { + return js.ValueOf(err.Error()) + } + if err := conn.WriteBinary(buf); err != nil { + log.Errorf("failed to send websocket binary: %v", err) + return js.ValueOf(false) + } + } + return js.ValueOf(true) + })) + + obj.Set("close", js.FuncOf(func(_ js.Value, _ []js.Value) any { + if err := conn.Close(); err != nil { + log.Debugf("failed to close websocket: %v", err) + } + return js.Undefined() + })) + + go readLoop(conn, obj) + + return obj +} + +func jsToBytes(data js.Value) ([]byte, error) { + var uint8Array js.Value + switch { + case data.InstanceOf(js.Global().Get("Uint8Array")): + uint8Array = data + case data.InstanceOf(js.Global().Get("ArrayBuffer")): + uint8Array = js.Global().Get("Uint8Array").New(data) + default: + return nil, fmt.Errorf("send: unsupported data type, use string or Uint8Array") + } + + buf := make([]byte, uint8Array.Get("length").Int()) + js.CopyBytesToGo(buf, uint8Array) + return buf, nil +} + +func readLoop(conn *Conn, obj js.Value) { + var closeCode uint16 + var closeReason string + var gotCloseFrame bool + + defer func() { + onclose := obj.Get("onclose") + if !onclose.Truthy() { + return + } + if gotCloseFrame { + onclose.Invoke(js.ValueOf(int(closeCode)), js.ValueOf(closeReason)) + return + } + onclose.Invoke() + }() + + for { + select { + case <-conn.closed: + return + default: + } + + op, payload, err := conn.ReadMessage() + if err != nil { + var ce *closeError + if errors.As(err, &ce) { + gotCloseFrame = true + closeCode = ce.code + closeReason = ce.reason + // Respond to server close per RFC 6455. + if err := conn.Close(); err != nil { + log.Debugf("failed to close websocket after server close frame: %v", err) + } + return + } + if err != io.EOF { + if onerror := obj.Get("onerror"); onerror.Truthy() { + onerror.Invoke(js.ValueOf(err.Error())) + } + } + return + } + + onmessage := obj.Get("onmessage") + if !onmessage.Truthy() { + continue + } + + switch op { + case ws.OpText: + onmessage.Invoke(js.ValueOf(string(payload))) + case ws.OpBinary: + uint8Array := js.Global().Get("Uint8Array").New(len(payload)) + js.CopyBytesToJS(uint8Array, payload) + onmessage.Invoke(uint8Array) + } + } +} + +func parseClosePayload(payload []byte) (uint16, string) { + if len(payload) < 2 { + return 1005, "" // RFC 6455: No Status Rcvd + } + code := binary.BigEndian.Uint16(payload[:2]) + return code, string(payload[2:]) +} diff --git a/go.mod b/go.mod index a95192600..3a7d968e7 100644 --- a/go.mod +++ b/go.mod @@ -190,6 +190,9 @@ require ( github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-text/render v0.2.0 // indirect github.com/go-text/typesetting v0.2.1 // indirect + github.com/gobwas/httphead v0.1.0 // indirect + github.com/gobwas/pool v0.2.1 // indirect + github.com/gobwas/ws v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/s2a-go v0.1.9 // indirect diff --git a/go.sum b/go.sum index a1d2bb71f..4670b3785 100644 --- a/go.sum +++ b/go.sum @@ -203,6 +203,12 @@ github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0= github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.3.2 h1:zlnbNHxumkRvfPWgfXu8RBwyNR1x8wh9cf5PTOCqs9Q= +github.com/gobwas/ws v1.3.2/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= @@ -738,6 +744,7 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=