mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Add wasm client
This commit is contained in:
0
.gitmodules
vendored
Normal file
0
.gitmodules
vendored
Normal file
@@ -2,6 +2,18 @@ version: 2
|
||||
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird-wasm
|
||||
dir: client/wasm/cmd
|
||||
binary: netbird.wasm
|
||||
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
|
||||
goos:
|
||||
- js
|
||||
goarch:
|
||||
- wasm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird
|
||||
dir: client
|
||||
binary: netbird
|
||||
@@ -115,6 +127,13 @@ archives:
|
||||
- builds:
|
||||
- netbird
|
||||
- netbird-static
|
||||
- id: netbird-wasm
|
||||
builds:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_wasm_{{ .Version }}"
|
||||
format: tar.gz
|
||||
files:
|
||||
- none*
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
8
client/cmd/debug_js.go
Normal file
8
client/cmd/debug_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package cmd
|
||||
|
||||
import "context"
|
||||
|
||||
// SetupDebugHandler is a no-op for WASM
|
||||
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
|
||||
// Debug handler not needed for WASM
|
||||
}
|
||||
@@ -23,23 +23,27 @@ import (
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
|
||||
// Client manages a netbird embedded client instance
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *profilemanager.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
}
|
||||
|
||||
// Options configures a new Client
|
||||
// Options configures a new Client.
|
||||
type Options struct {
|
||||
// DeviceName is this peer's name in the network
|
||||
DeviceName string
|
||||
// SetupKey is used for authentication
|
||||
SetupKey string
|
||||
// JWTToken is used for JWT-based authentication
|
||||
JWTToken string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
@@ -58,8 +62,15 @@ type Options struct {
|
||||
DisableClientRoutes bool
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client
|
||||
// New creates a new netbird embedded client.
|
||||
func New(opts Options) (*Client, error) {
|
||||
if opts.SetupKey == "" && opts.JWTToken == "" {
|
||||
return nil, fmt.Errorf("either SetupKey or JWTToken must be provided")
|
||||
}
|
||||
if opts.SetupKey != "" && opts.JWTToken != "" {
|
||||
return nil, fmt.Errorf("cannot specify both SetupKey and JWTToken")
|
||||
}
|
||||
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
@@ -110,6 +121,7 @@ func New(opts Options) (*Client, error) {
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
jwtToken: opts.JWTToken,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
@@ -126,7 +138,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
@@ -187,6 +199,16 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns a copy of the internal client config.
|
||||
func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.config == nil {
|
||||
return profilemanager.Config{}, ErrConfigNotInitialized
|
||||
}
|
||||
return *c.config, nil
|
||||
}
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
@@ -211,7 +233,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
@@ -232,7 +254,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
return nsnet.ListenTCP(tcpAddr)
|
||||
}
|
||||
|
||||
// ListenUDP listens on the given address in the netbird network
|
||||
// ListenUDP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
|
||||
69
client/iface/bind/ice_bind_common.go
Normal file
69
client/iface/bind/ice_bind_common.go
Normal file
@@ -0,0 +1,69 @@
|
||||
//go:build js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v3"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// RecvMessage represents a received message
|
||||
type RecvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation that uses ICE candidates for connectivity
|
||||
type ICEBind struct {
|
||||
address wgaddr.Address
|
||||
filterFn FilterFn
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
muUDPMux sync.Mutex
|
||||
transportNet transport.Net
|
||||
receiverCreated bool
|
||||
activityRecorder *ActivityRecorder
|
||||
RecvChan chan RecvMessage
|
||||
closed bool // Flag to signal that bind is closed
|
||||
closedMu sync.Mutex
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
// NewICEBind creates a new ICEBind instance
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
return &ICEBind{
|
||||
address: address,
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
RecvChan: make(chan RecvMessage, 100),
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
mtu: mtu,
|
||||
}
|
||||
}
|
||||
|
||||
// SetFilter updates the filter function
|
||||
func (s *ICEBind) SetFilter(filter FilterFn) {
|
||||
s.filterFn = filter
|
||||
}
|
||||
|
||||
// GetAddress returns the bind address
|
||||
func (s *ICEBind) GetAddress() wgaddr.Address {
|
||||
return s.address
|
||||
}
|
||||
|
||||
// ActivityRecorder returns the activity recorder
|
||||
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||
return s.activityRecorder
|
||||
}
|
||||
|
||||
// MTU returns the maximum transmission unit
|
||||
func (s *ICEBind) MTU() uint16 {
|
||||
return s.mtu
|
||||
}
|
||||
141
client/iface/bind/ice_bind_js.go
Normal file
141
client/iface/bind/ice_bind_js.go
Normal file
@@ -0,0 +1,141 @@
|
||||
//go:build js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// GetICEMux returns a dummy UDP mux for WASM since browsers don't support UDP.
|
||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Open creates a receive function for handling relay packets in WASM.
|
||||
func (s *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
log.Debugf("Open: creating receive function for port %d", uport)
|
||||
|
||||
s.closedMu.Lock()
|
||||
s.closed = false
|
||||
s.closedMu.Unlock()
|
||||
|
||||
if !s.receiverCreated {
|
||||
s.receiverCreated = true
|
||||
log.Debugf("Open: first call, setting receiverCreated=true")
|
||||
}
|
||||
|
||||
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
|
||||
s.closedMu.Lock()
|
||||
if s.closed {
|
||||
s.closedMu.Unlock()
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
s.closedMu.Unlock()
|
||||
|
||||
msg, ok := <-s.RecvChan
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
copy(bufs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = conn.Endpoint(msg.Endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
log.Debugf("Open: receive function created, returning port %d", uport)
|
||||
return []conn.ReceiveFunc{receiveFn}, uport, nil
|
||||
}
|
||||
|
||||
// SetMark is not applicable in WASM/browser environment.
|
||||
func (s *ICEBind) SetMark(_ uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send forwards packets through the relay connection for WASM.
|
||||
func (s *ICEBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
if ep == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fakeIP := ep.DstIP()
|
||||
|
||||
s.endpointsMu.Lock()
|
||||
relayConn, ok := s.endpoints[fakeIP]
|
||||
s.endpointsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, buf := range bufs {
|
||||
if _, err := relayConn.Write(buf); err != nil {
|
||||
log.Errorf("Send: failed to write to relay: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetEndpoint stores a relay endpoint for a fake IP.
|
||||
func (s *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
s.endpointsMu.Lock()
|
||||
defer s.endpointsMu.Unlock()
|
||||
|
||||
if oldConn, exists := s.endpoints[fakeIP]; exists {
|
||||
if oldConn != conn {
|
||||
log.Debugf("SetEndpoint: replacing existing connection for %s", fakeIP)
|
||||
if err := oldConn.Close(); err != nil {
|
||||
log.Debugf("SetEndpoint: error closing old connection: %v", err)
|
||||
}
|
||||
s.endpoints[fakeIP] = conn
|
||||
} else {
|
||||
log.Tracef("SetEndpoint: same connection already set for %s, skipping", fakeIP)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("SetEndpoint: setting new relay connection for fake IP %s", fakeIP)
|
||||
s.endpoints[fakeIP] = conn
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveEndpoint removes a relay endpoint.
|
||||
func (s *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
s.endpointsMu.Lock()
|
||||
defer s.endpointsMu.Unlock()
|
||||
delete(s.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
// BatchSize returns the batch size for WASM.
|
||||
func (s *ICEBind) BatchSize() int {
|
||||
return 1
|
||||
}
|
||||
|
||||
// ParseEndpoint parses an endpoint string.
|
||||
func (s *ICEBind) ParseEndpoint(s2 string) (conn.Endpoint, error) {
|
||||
addrPort, err := netip.ParseAddrPort(s2)
|
||||
if err != nil {
|
||||
log.Errorf("ParseEndpoint: failed to parse %s: %v", s2, err)
|
||||
return nil, err
|
||||
}
|
||||
ep := &Endpoint{AddrPort: addrPort}
|
||||
return ep, nil
|
||||
}
|
||||
|
||||
// Close closes the ICEBind.
|
||||
func (s *ICEBind) Close() error {
|
||||
log.Debugf("Close: closing ICEBind (receiverCreated=%v)", s.receiverCreated)
|
||||
|
||||
s.closedMu.Lock()
|
||||
s.closed = true
|
||||
s.closedMu.Unlock()
|
||||
|
||||
s.receiverCreated = false
|
||||
|
||||
log.Debugf("Close: returning from Close")
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || windows || freebsd
|
||||
//go:build linux || windows || freebsd || js || wasip1
|
||||
|
||||
package configurer
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !windows
|
||||
//go:build !windows && !js
|
||||
|
||||
package configurer
|
||||
|
||||
|
||||
23
client/iface/configurer/uapi_js.go
Normal file
23
client/iface/configurer/uapi_js.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type noopListener struct{}
|
||||
|
||||
func (n *noopListener) Accept() (net.Conn, error) {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (n *noopListener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noopListener) Addr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
return &noopListener{}, nil
|
||||
}
|
||||
6
client/iface/iface_destroy_js.go
Normal file
6
client/iface/iface_destroy_js.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package iface
|
||||
|
||||
// Destroy is a no-op on WASM
|
||||
func (w *WGIface) Destroy() error {
|
||||
return nil
|
||||
}
|
||||
27
client/iface/iface_new_js.go
Normal file
27
client/iface/iface_new_js.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
|
||||
wgIface := &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
|
||||
return wgIface, nil
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
|
||||
12
client/iface/netstack/env_js.go
Normal file
12
client/iface/netstack/env_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package netstack
|
||||
|
||||
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||
|
||||
// IsEnabled always returns true for js since it's the only mode available
|
||||
func IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func ListenAddr() string {
|
||||
return ""
|
||||
}
|
||||
5
client/internal/dns/server_js.go
Normal file
5
client/internal/dns/server_js.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||
return &noopHostConfigurator{}, nil
|
||||
}
|
||||
19
client/internal/dns/unclean_shutdown_js.go
Normal file
19
client/internal/dns/unclean_shutdown_js.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ShutdownState struct{}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
@@ -450,14 +450,7 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
iceCfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
iceCfg := e.createICEConfig()
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||
e.connMgr.Start(e.ctx)
|
||||
@@ -1288,14 +1281,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
Addr: e.getRosenpassAddr(),
|
||||
PermissiveMode: e.config.RosenpassPermissive,
|
||||
},
|
||||
ICEConfig: icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
},
|
||||
ICEConfig: e.createICEConfig(),
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
|
||||
19
client/internal/engine_generic.go
Normal file
19
client/internal/engine_generic.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
// createICEConfig creates ICE configuration for non-WASM environments
|
||||
func (e *Engine) createICEConfig() icemaker.Config {
|
||||
return icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.UDPMuxDefault,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
}
|
||||
24
client/internal/engine_js.go
Normal file
24
client/internal/engine_js.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
// createICEConfig creates ICE configuration for WASM environment.
|
||||
func (e *Engine) createICEConfig() icemaker.Config {
|
||||
cfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
|
||||
if e.udpMux != nil {
|
||||
cfg.UDPMux = e.udpMux.UDPMuxDefault
|
||||
cfg.UDPMuxSrflx = e.udpMux
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
12
client/internal/networkmonitor/check_change_js.go
Normal file
12
client/internal/networkmonitor/check_change_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
// No-op for WASM - network changes don't apply
|
||||
return nil
|
||||
}
|
||||
@@ -455,6 +455,14 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we already have a relay proxy configured
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.Log.Debugf("Relay proxy already configured, skipping duplicate setup")
|
||||
// Update status to ensure it's connected
|
||||
conn.statusRelay.SetConnected()
|
||||
return
|
||||
}
|
||||
|
||||
conn.dumpState.RelayConnected()
|
||||
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||
|
||||
@@ -472,19 +480,42 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
if conn.isICEActive() {
|
||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
// For WASM, we still need to start the proxy and configure WireGuard
|
||||
// because ICE doesn't actually work in browsers and we rely on relay
|
||||
conn.Log.Infof("WASM check: runtime.GOOS=%s, should start proxy=%v", runtime.GOOS, runtime.GOOS == "js")
|
||||
if runtime.GOOS == "js" {
|
||||
conn.Log.Infof("WASM: starting relay proxy and configuring WireGuard despite ICE being 'active'")
|
||||
wgProxy.Work()
|
||||
|
||||
// Configure WireGuard to use the relay proxy endpoint
|
||||
endpointAddr := wgProxy.EndpointAddr()
|
||||
conn.Log.Infof("WASM: Configuring WireGuard endpoint to proxy address: %v", endpointAddr)
|
||||
if err := conn.configureWGEndpoint(endpointAddr, rci.rosenpassPubKey); err != nil {
|
||||
conn.Log.Errorf("WASM: Failed to update WireGuard peer configuration: %v", err)
|
||||
} else {
|
||||
conn.Log.Infof("WASM: Successfully configured WireGuard endpoint to use proxy at %v", endpointAddr)
|
||||
}
|
||||
|
||||
// Update connection priority to relay for WASM
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
}
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
endpointAddr := wgProxy.EndpointAddr()
|
||||
conn.Log.Infof("Configuring WireGuard endpoint to proxy address: %v", endpointAddr)
|
||||
if err := conn.configureWGEndpoint(endpointAddr, rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
conn.Log.Infof("Successfully configured WireGuard endpoint to use proxy at %v", endpointAddr)
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
@@ -663,6 +694,21 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
// In WASM with forced relay, ICE is not used, so skip ICE check
|
||||
if runtime.GOOS == "js" && os.Getenv("NB_FORCE_RELAY") == "true" {
|
||||
// Only check relay connection status
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
relayConnected := conn.statusRelay.Get() != worker.StatusDisconnected
|
||||
if !relayConnected {
|
||||
conn.Log.Tracef("WASM: relay not connected for connectivity check")
|
||||
}
|
||||
return relayConnected
|
||||
}
|
||||
// If relay is not supported, consider it connected to avoid reconnect loop
|
||||
conn.Log.Tracef("WASM: relay not supported, returning true to avoid reconnect")
|
||||
return true
|
||||
}
|
||||
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -55,6 +54,22 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
}
|
||||
w.relaySupportedOnRemotePeer.Store(true)
|
||||
|
||||
// Check if we already have an active relay connection
|
||||
w.relayLock.Lock()
|
||||
existingConn := w.relayedConn
|
||||
w.relayLock.Unlock()
|
||||
|
||||
if existingConn != nil {
|
||||
w.log.Debugf("relay connection already exists for peer %s, reusing it", w.config.Key)
|
||||
// Connection exists, just ensure proxy is set up if needed
|
||||
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||
relayedConn: existingConn,
|
||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// the relayManager will return with error in case if the connection has lost with relay server
|
||||
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
@@ -66,15 +81,24 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
|
||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||
if err != nil {
|
||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||
return
|
||||
}
|
||||
// The relay manager never actually returns ErrConnAlreadyExists - it returns
|
||||
// the existing connection with nil error. This error handling is for other failures.
|
||||
w.log.Errorf("failed to open connection via Relay: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.relayLock.Lock()
|
||||
// Check if we already stored this connection (might happen if OpenConn returned existing)
|
||||
if w.relayedConn != nil && w.relayedConn == relayedConn {
|
||||
w.relayLock.Unlock()
|
||||
w.log.Debugf("OpenConn returned the same connection we already have for peer %s", w.config.Key)
|
||||
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||
relayedConn: relayedConn,
|
||||
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||
})
|
||||
return
|
||||
}
|
||||
w.relayedConn = relayedConn
|
||||
w.relayLock.Unlock()
|
||||
|
||||
@@ -123,11 +147,16 @@ func (w *WorkerRelay) CloseConn() {
|
||||
if err := w.relayedConn.Close(); err != nil {
|
||||
w.log.Warnf("failed to close relay connection: %v", err)
|
||||
}
|
||||
// Clear the stored connection to allow reopening
|
||||
w.relayedConn = nil
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onWGDisconnected() {
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
if w.relayedConn != nil {
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayedConn = nil
|
||||
}
|
||||
w.relayLock.Unlock()
|
||||
|
||||
w.conn.onRelayDisconnected()
|
||||
@@ -148,6 +177,11 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||
// Clear the stored connection when relay disconnects
|
||||
w.relayLock.Lock()
|
||||
w.relayedConn = nil
|
||||
w.relayLock.Unlock()
|
||||
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
go w.conn.onRelayDisconnected()
|
||||
}
|
||||
|
||||
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
var ErrRouteNotSupported = errors.New("route operations not supported on js")
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||
return []DetailedRoute{}, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !ios
|
||||
//go:build !linux && !ios && !js
|
||||
|
||||
package systemops
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import "context"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
137
client/ssh/ssh_js.go
Normal file
137
client/ssh/ssh_js.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment")
|
||||
|
||||
// Server is a dummy SSH server interface for WASM.
|
||||
type Server interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
EnableSSH(enabled bool)
|
||||
AddAuthorizedKey(peer string, key string) error
|
||||
RemoveAuthorizedKey(key string)
|
||||
}
|
||||
|
||||
type dummyServer struct{}
|
||||
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return &dummyServer{}, nil
|
||||
}
|
||||
|
||||
func NewServer(addr string) Server {
|
||||
return &dummyServer{}
|
||||
}
|
||||
|
||||
func (s *dummyServer) Start() error {
|
||||
return ErrSSHNotSupported
|
||||
}
|
||||
|
||||
func (s *dummyServer) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dummyServer) EnableSSH(enabled bool) {
|
||||
}
|
||||
|
||||
func (s *dummyServer) AddAuthorizedKey(peer string, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dummyServer) RemoveAuthorizedKey(key string) {
|
||||
}
|
||||
|
||||
type Client struct{}
|
||||
|
||||
func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) {
|
||||
return nil, ErrSSHNotSupported
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Run(command []string) error {
|
||||
return ErrSSHNotSupported
|
||||
}
|
||||
|
||||
type SessionRecorder struct{}
|
||||
|
||||
func NewSessionRecorder() *SessionRecorder {
|
||||
return &SessionRecorder{}
|
||||
}
|
||||
|
||||
func (r *SessionRecorder) Record(session string, data []byte) {
|
||||
}
|
||||
|
||||
func GetUserShell() string {
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
func LookupUserInfo(username string) (string, string, error) {
|
||||
return "", "", ErrSSHNotSupported
|
||||
}
|
||||
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
const ED25519 = "ed25519"
|
||||
|
||||
func isRoot() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func GeneratePrivateKey(keyType string) ([]byte, error) {
|
||||
if keyType != ED25519 {
|
||||
return nil, errors.New("only ED25519 keys are supported in WASM")
|
||||
}
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Bytes,
|
||||
}
|
||||
|
||||
pemBytes := pem.EncodeToMemory(pemBlock)
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
func GeneratePublicKey(privateKey []byte) ([]byte, error) {
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
block, _ := pem.Decode(privateKey)
|
||||
if block != nil {
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signer, err = ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
||||
return []byte(strings.TrimSpace(string(pubKeyBytes))), nil
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
234
client/system/info_js.go
Normal file
234
client/system/info_js.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// GetInfo retrieves system information for WASM environment
|
||||
func GetInfo(_ context.Context) *Info {
|
||||
info := &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: runtime.GOARCH,
|
||||
KernelVersion: runtime.GOARCH,
|
||||
Platform: runtime.GOARCH,
|
||||
OS: runtime.GOARCH,
|
||||
Hostname: "wasm-client",
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
collectBrowserInfo(info)
|
||||
collectLocationInfo(info)
|
||||
|
||||
si := updateStaticInfo()
|
||||
info.SystemSerialNumber = si.SystemSerialNumber
|
||||
info.SystemProductName = si.SystemProductName
|
||||
info.SystemManufacturer = si.SystemManufacturer
|
||||
info.Environment = si.Environment
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
func collectBrowserInfo(info *Info) {
|
||||
navigator := js.Global().Get("navigator")
|
||||
if navigator.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
collectUserAgent(info, navigator)
|
||||
collectPlatform(info, navigator)
|
||||
collectCPUInfo(info, navigator)
|
||||
}
|
||||
|
||||
func collectUserAgent(info *Info, navigator js.Value) {
|
||||
ua := navigator.Get("userAgent")
|
||||
if ua.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := ua.String()
|
||||
os, osVersion := parseOSFromUserAgent(userAgent)
|
||||
if os != "" {
|
||||
info.OS = os
|
||||
}
|
||||
if osVersion != "" {
|
||||
info.OSVersion = osVersion
|
||||
}
|
||||
}
|
||||
|
||||
func collectPlatform(info *Info, navigator js.Value) {
|
||||
// Try regular platform property
|
||||
if plat := navigator.Get("platform"); !plat.IsUndefined() {
|
||||
if platStr := plat.String(); platStr != "" {
|
||||
info.Platform = platStr
|
||||
}
|
||||
}
|
||||
|
||||
// Try newer userAgentData API for more accurate platform
|
||||
userAgentData := navigator.Get("userAgentData")
|
||||
if userAgentData.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
platformInfo := userAgentData.Get("platform")
|
||||
if !platformInfo.IsUndefined() {
|
||||
if platStr := platformInfo.String(); platStr != "" {
|
||||
info.Platform = platStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectCPUInfo(info *Info, navigator js.Value) {
|
||||
hardwareConcurrency := navigator.Get("hardwareConcurrency")
|
||||
if !hardwareConcurrency.IsUndefined() {
|
||||
info.CPUs = hardwareConcurrency.Int()
|
||||
}
|
||||
}
|
||||
|
||||
func collectLocationInfo(info *Info) {
|
||||
location := js.Global().Get("location")
|
||||
if location.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
if host := location.Get("hostname"); !host.IsUndefined() {
|
||||
hostnameStr := host.String()
|
||||
if hostnameStr != "" && hostnameStr != "localhost" {
|
||||
info.Hostname = hostnameStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
func updateStaticInfo() *StaticInfo {
|
||||
si := &StaticInfo{}
|
||||
|
||||
navigator := js.Global().Get("navigator")
|
||||
if !navigator.IsUndefined() {
|
||||
if vendor := navigator.Get("vendor"); !vendor.IsUndefined() {
|
||||
si.SystemManufacturer = vendor.String()
|
||||
}
|
||||
|
||||
if product := navigator.Get("product"); !product.IsUndefined() {
|
||||
si.SystemProductName = product.String()
|
||||
}
|
||||
|
||||
if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() {
|
||||
ua := userAgent.String()
|
||||
si.Environment = detectEnvironmentFromUA(ua)
|
||||
}
|
||||
}
|
||||
|
||||
return si
|
||||
}
|
||||
|
||||
func parseOSFromUserAgent(userAgent string) (string, string) {
|
||||
if userAgent == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(userAgent, "Windows NT"):
|
||||
return parseWindowsVersion(userAgent)
|
||||
case strings.Contains(userAgent, "Mac OS X"):
|
||||
return parseMacOSVersion(userAgent)
|
||||
case strings.Contains(userAgent, "FreeBSD"):
|
||||
return "FreeBSD", ""
|
||||
case strings.Contains(userAgent, "OpenBSD"):
|
||||
return "OpenBSD", ""
|
||||
case strings.Contains(userAgent, "NetBSD"):
|
||||
return "NetBSD", ""
|
||||
case strings.Contains(userAgent, "Linux"):
|
||||
return parseLinuxVersion(userAgent)
|
||||
case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"):
|
||||
return parseiOSVersion(userAgent)
|
||||
case strings.Contains(userAgent, "CrOS"):
|
||||
return "ChromeOS", ""
|
||||
default:
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
||||
func parseWindowsVersion(userAgent string) (string, string) {
|
||||
switch {
|
||||
case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"):
|
||||
return "Windows", "10/11"
|
||||
case strings.Contains(userAgent, "Windows NT 10.0"):
|
||||
return "Windows", "10"
|
||||
case strings.Contains(userAgent, "Windows NT 6.3"):
|
||||
return "Windows", "8.1"
|
||||
case strings.Contains(userAgent, "Windows NT 6.2"):
|
||||
return "Windows", "8"
|
||||
case strings.Contains(userAgent, "Windows NT 6.1"):
|
||||
return "Windows", "7"
|
||||
default:
|
||||
return "Windows", "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func parseMacOSVersion(userAgent string) (string, string) {
|
||||
idx := strings.Index(userAgent, "Mac OS X ")
|
||||
if idx == -1 {
|
||||
return "macOS", "Unknown"
|
||||
}
|
||||
|
||||
versionStart := idx + len("Mac OS X ")
|
||||
versionEnd := strings.Index(userAgent[versionStart:], ")")
|
||||
if versionEnd <= 0 {
|
||||
return "macOS", "Unknown"
|
||||
}
|
||||
|
||||
ver := userAgent[versionStart : versionStart+versionEnd]
|
||||
ver = strings.ReplaceAll(ver, "_", ".")
|
||||
return "macOS", ver
|
||||
}
|
||||
|
||||
func parseLinuxVersion(userAgent string) (string, string) {
|
||||
if strings.Contains(userAgent, "Android") {
|
||||
return "Android", extractAndroidVersion(userAgent)
|
||||
}
|
||||
if strings.Contains(userAgent, "Ubuntu") {
|
||||
return "Ubuntu", ""
|
||||
}
|
||||
return "Linux", ""
|
||||
}
|
||||
|
||||
func parseiOSVersion(userAgent string) (string, string) {
|
||||
idx := strings.Index(userAgent, "OS ")
|
||||
if idx == -1 {
|
||||
return "iOS", "Unknown"
|
||||
}
|
||||
|
||||
versionStart := idx + 3
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd <= 0 {
|
||||
return "iOS", "Unknown"
|
||||
}
|
||||
|
||||
ver := userAgent[versionStart : versionStart+versionEnd]
|
||||
ver = strings.ReplaceAll(ver, "_", ".")
|
||||
return "iOS", ver
|
||||
}
|
||||
|
||||
func extractAndroidVersion(userAgent string) string {
|
||||
if idx := strings.Index(userAgent, "Android "); idx != -1 {
|
||||
versionStart := idx + len("Android ")
|
||||
versionEnd := strings.IndexAny(userAgent[versionStart:], ";)")
|
||||
if versionEnd > 0 {
|
||||
return userAgent[versionStart : versionStart+versionEnd]
|
||||
}
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
func detectEnvironmentFromUA(_ string) Environment {
|
||||
return Environment{}
|
||||
}
|
||||
236
client/wasm/cmd/main.go
Normal file
236
client/wasm/cmd/main.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
func main() {
|
||||
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
|
||||
|
||||
select {}
|
||||
}
|
||||
|
||||
func startClient(ctx context.Context, nbClient *netbird.Client) error {
|
||||
log.Println("Starting NetBird client...")
|
||||
if err := nbClient.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Println("NetBird client started successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseClientOptions extracts NetBird options from JavaScript object
|
||||
func parseClientOptions(jsOptions js.Value) (netbird.Options, error) {
|
||||
options := netbird.Options{
|
||||
DeviceName: "dashboard-client",
|
||||
LogLevel: "warn",
|
||||
}
|
||||
|
||||
if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() {
|
||||
options.JWTToken = jwtToken.String()
|
||||
}
|
||||
|
||||
if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() {
|
||||
options.SetupKey = setupKey.String()
|
||||
}
|
||||
|
||||
if options.JWTToken == "" && options.SetupKey == "" {
|
||||
return options, fmt.Errorf("either jwtToken or setupKey must be provided")
|
||||
}
|
||||
|
||||
if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() {
|
||||
mgmtURLStr := mgmtURL.String()
|
||||
if mgmtURLStr != "" {
|
||||
options.ManagementURL = mgmtURLStr
|
||||
}
|
||||
}
|
||||
|
||||
if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() {
|
||||
options.LogLevel = logLevel.String()
|
||||
}
|
||||
|
||||
if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() {
|
||||
options.DeviceName = deviceName.String()
|
||||
}
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// createStartMethod creates the start method for the client
|
||||
func createStartMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := startClient(ctx, client); err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStopMethod creates the stop method for the client
|
||||
func createStopMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
log.Printf("Error stopping client: %v", err)
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("NetBird client stopped")
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createSSHMethod creates the SSH connection method
|
||||
func createSSHMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
username := "root"
|
||||
if len(args) > 2 && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username); err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := sshClient.StartSession(80, 24); err != nil {
|
||||
if closeErr := sshClient.Close(); closeErr != nil {
|
||||
log.Printf("Error closing SSH client: %v", closeErr)
|
||||
}
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsInterface := ssh.CreateJSInterface(sshClient)
|
||||
resolve.Invoke(jsInterface)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createProxyRequestMethod creates the proxyRequest method
|
||||
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: request details required")
|
||||
}
|
||||
|
||||
request := args[0]
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
response, err := http.ProxyRequest(client, request)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
resolve.Invoke(response)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createRDPProxyMethod creates the RDP proxy method
|
||||
func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
go handler(resolve, reject)
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// createClientObject wraps the NetBird client in a JavaScript object
|
||||
func createClientObject(client *netbird.Client) js.Value {
|
||||
obj := make(map[string]interface{})
|
||||
|
||||
obj["start"] = createStartMethod(client)
|
||||
obj["stop"] = createStopMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
if len(args) < 1 {
|
||||
reject.Invoke(js.ValueOf("Options object required"))
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
options, err := parseClientOptions(args[0])
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s",
|
||||
options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL)
|
||||
|
||||
client, err := netbird.New(options)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
clientObj := createClientObject(client)
|
||||
log.Println("NetBird client created successfully")
|
||||
resolve.Invoke(clientObj)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
98
client/wasm/internal/http/http.go
Normal file
98
client/wasm/internal/http/http.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
httpTimeout = 30 * time.Second
|
||||
maxResponseSize = 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
// performRequest executes an HTTP request through NetBird and returns the response and body
|
||||
func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) {
|
||||
httpClient := nbClient.NewHTTPClient()
|
||||
httpClient.Timeout = httpTimeout
|
||||
|
||||
req, err := http.NewRequest(method, url, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
log.Printf("failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return resp, respBody, nil
|
||||
}
|
||||
|
||||
// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object
|
||||
func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) {
|
||||
url := request.Get("url").String()
|
||||
if url == "" {
|
||||
return js.Undefined(), fmt.Errorf("URL is required")
|
||||
}
|
||||
|
||||
method := "GET"
|
||||
if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() {
|
||||
method = strings.ToUpper(methodVal.String())
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() {
|
||||
requestBody = []byte(bodyVal.String())
|
||||
}
|
||||
|
||||
requestHeaders := make(map[string]string)
|
||||
if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject {
|
||||
headerKeys := js.Global().Get("Object").Call("keys", headersVal)
|
||||
for i := 0; i < headerKeys.Length(); i++ {
|
||||
key := headerKeys.Index(i).String()
|
||||
value := headersVal.Get(key).String()
|
||||
requestHeaders[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody)
|
||||
if err != nil {
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
result := js.Global().Get("Object").New()
|
||||
result.Set("status", resp.StatusCode)
|
||||
result.Set("statusText", resp.Status)
|
||||
result.Set("body", string(body))
|
||||
|
||||
headers := js.Global().Get("Object").New()
|
||||
for key, values := range resp.Header {
|
||||
if len(values) > 0 {
|
||||
headers.Set(strings.ToLower(key), values[0])
|
||||
}
|
||||
}
|
||||
result.Set("headers", headers)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
94
client/wasm/internal/rdp/cert_validation.go
Normal file
94
client/wasm/internal/rdp/cert_validation.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
|
||||
return false, fmt.Errorf("certificate validation handler not configured")
|
||||
}
|
||||
|
||||
certInfo := js.Global().Get("Object").New()
|
||||
certInfo.Set("ServerAddr", conn.destination)
|
||||
|
||||
certArray := js.Global().Get("Array").New()
|
||||
for i, certBytes := range certChain {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
|
||||
js.CopyBytesToJS(uint8Array, certBytes)
|
||||
certArray.SetIndex(i, uint8Array)
|
||||
}
|
||||
certInfo.Set("ServerCertChain", certArray)
|
||||
if len(certChain) > 0 {
|
||||
cert, err := x509.ParseCertificate(certChain[0])
|
||||
if err == nil {
|
||||
info := js.Global().Get("Object").New()
|
||||
info.Set("subject", cert.Subject.String())
|
||||
info.Set("issuer", cert.Issuer.String())
|
||||
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
|
||||
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
|
||||
info.Set("serialNumber", cert.SerialNumber.String())
|
||||
certInfo.Set("CertificateInfo", info)
|
||||
}
|
||||
}
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result {
|
||||
log.Info("Certificate accepted by user")
|
||||
} else {
|
||||
log.Info("Certificate rejected by user")
|
||||
}
|
||||
return result, nil
|
||||
case err := <-errorChan:
|
||||
return false, err
|
||||
case <-time.After(certValidationTimeout):
|
||||
return false, fmt.Errorf("certificate validation timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
|
||||
accepted, err := p.validateCertificateWithJS(conn, certChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !accepted {
|
||||
return fmt.Errorf("certificate rejected by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
269
client/wasm/internal/rdp/rdcleanpath.go
Normal file
269
client/wasm/internal/rdp/rdcleanpath.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type proxyConnection struct {
|
||||
id string
|
||||
destination string
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
func NewRDCleanPathProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *RDCleanPathProxy {
|
||||
return &RDCleanPathProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*proxyConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]string)
|
||||
}
|
||||
p.destinations[proxyID] = destination
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
|
||||
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
destination := p.destinations[proxyID]
|
||||
p.mu.Unlock()
|
||||
|
||||
if destination == "" {
|
||||
log.Errorf("No destination found for proxy ID: %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Don't defer cancel here - it will be called by cleanupConnection
|
||||
|
||||
conn := &proxyConnection{
|
||||
id: proxyID,
|
||||
destination: destination,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
|
||||
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
|
||||
if conn.rdpConn != nil || conn.tlsConn != nil {
|
||||
p.forwardToRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
|
||||
var pdu RDCleanPathPDU
|
||||
_, err := asn1.Unmarshal(bytes, &pdu)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
|
||||
n := len(bytes)
|
||||
if n > 20 {
|
||||
n = 20
|
||||
}
|
||||
log.Warnf("First %d bytes: %x", n, bytes[:n])
|
||||
|
||||
if len(bytes) > 0 && bytes[0] == 0x03 {
|
||||
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
|
||||
go p.handleDirectRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go p.processRDCleanPathPDU(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
|
||||
var writer io.Writer
|
||||
var connType string
|
||||
|
||||
if conn.tlsConn != nil {
|
||||
writer = conn.tlsConn
|
||||
connType = "TLS"
|
||||
} else if conn.rdpConn != nil {
|
||||
writer = conn.rdpConn
|
||||
connType = "TCP"
|
||||
} else {
|
||||
log.Error("No RDP connection available")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := writer.Write(bytes); err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
|
||||
defer p.cleanupConnection(conn)
|
||||
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, response[:n])
|
||||
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
249
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
249
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
return
|
||||
}
|
||||
|
||||
destination := conn.destination
|
||||
if pdu.Destination != "" {
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
// RDP always starts with X.224 negotiation, then determines if TLS is needed
|
||||
// Modern RDP (since Windows Vista/2008) typically requires TLS
|
||||
// The X.224 Connection Confirm response will indicate if TLS is required
|
||||
// For now, we'll attempt TLS for all connections as it's the modern default
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("TLS handshake successful")
|
||||
|
||||
// Certificate validation happens during handshake via VerifyConnection callback
|
||||
var certChain [][]byte
|
||||
connState := tlsConn.ConnectionState()
|
||||
if len(connState.PeerCertificates) > 0 {
|
||||
for _, cert := range connState.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
ServerCertChain: certChain,
|
||||
}
|
||||
|
||||
if len(x224Response) > 0 {
|
||||
responsePDU.X224ConnectionPDU = x224Response
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
|
||||
log.Debug("Starting TLS forwarding")
|
||||
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
|
||||
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TLS connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
errChan <- io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
if data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
msgChan <- bytes
|
||||
}
|
||||
return nil
|
||||
})
|
||||
defer handler.Release()
|
||||
|
||||
conn.wsHandlers.Set("onceGoMessage", handler)
|
||||
|
||||
select {
|
||||
case msg := <-msgChan:
|
||||
return msg, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-conn.ctx.Done():
|
||||
return nil, conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := p.readWebSocketMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from WebSocket: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from %s: %v", connType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
211
client/wasm/internal/ssh/client.go
Normal file
211
client/wasm/internal/ssh/client.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
func closeWithLog(c io.Closer, resource string) {
|
||||
if c != nil {
|
||||
if err := c.Close(); err != nil {
|
||||
logrus.Debugf("Failed to close %s: %v", resource, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
nbClient *netbird.Client
|
||||
sshClient *ssh.Client
|
||||
session *ssh.Session
|
||||
stdin io.WriteCloser
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient creates a new SSH client
|
||||
func NewClient(nbClient *netbird.Client) *Client {
|
||||
return &Client{
|
||||
nbClient: nbClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes an SSH connection through NetBird network
|
||||
func (c *Client) Connect(host string, port int, username string) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||
|
||||
var authMethods []ssh.AuthMethod
|
||||
|
||||
nbConfig, err := c.nbClient.GetConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get NetBird config: %w", err)
|
||||
}
|
||||
if nbConfig.SSHKey == "" {
|
||||
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
|
||||
}
|
||||
|
||||
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||
}
|
||||
|
||||
pubKey := signer.PublicKey()
|
||||
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
|
||||
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
closeWithLog(conn, "connection after handshake error")
|
||||
return fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
|
||||
logrus.Infof("SSH: Connected to %s", addr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSession starts an SSH session with PTY
|
||||
func (c *Client) StartSession(cols, rows int) error {
|
||||
if c.sshClient == nil {
|
||||
return fmt.Errorf("SSH client not connected")
|
||||
}
|
||||
|
||||
session, err := c.sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.session = session
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.VINTR: 3,
|
||||
ssh.VQUIT: 28,
|
||||
ssh.VERASE: 127,
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeWithLog(session, "session after PTY error")
|
||||
return fmt.Errorf("PTY request: %w", err)
|
||||
}
|
||||
|
||||
c.stdin, err = session.StdinPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stdin error")
|
||||
return fmt.Errorf("get stdin: %w", err)
|
||||
}
|
||||
|
||||
c.stdout, err = session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stdout error")
|
||||
return fmt.Errorf("get stdout: %w", err)
|
||||
}
|
||||
|
||||
c.stderr, err = session.StderrPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stderr error")
|
||||
return fmt.Errorf("get stderr: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeWithLog(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("SSH: Session started with PTY")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session
|
||||
func (c *Client) Write(data []byte) (int, error) {
|
||||
c.mu.RLock()
|
||||
stdin := c.stdin
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stdin == nil {
|
||||
return 0, fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return stdin.Write(data)
|
||||
}
|
||||
|
||||
// Read reads data from the SSH session
|
||||
func (c *Client) Read(buffer []byte) (int, error) {
|
||||
c.mu.RLock()
|
||||
stdout := c.stdout
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stdout == nil {
|
||||
return 0, fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return stdout.Read(buffer)
|
||||
}
|
||||
|
||||
// Resize updates the terminal size
|
||||
func (c *Client) Resize(cols, rows int) error {
|
||||
c.mu.RLock()
|
||||
session := c.session
|
||||
c.mu.RUnlock()
|
||||
|
||||
if session == nil {
|
||||
return fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close closes the SSH connection
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.session != nil {
|
||||
closeWithLog(c.session, "SSH session")
|
||||
c.session = nil
|
||||
}
|
||||
if c.stdin != nil {
|
||||
closeWithLog(c.stdin, "stdin")
|
||||
c.stdin = nil
|
||||
}
|
||||
c.stdout = nil
|
||||
c.stderr = nil
|
||||
|
||||
if c.sshClient != nil {
|
||||
err := c.sshClient.Close()
|
||||
c.sshClient = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
76
client/wasm/internal/ssh/handlers.go
Normal file
76
client/wasm/internal/ssh/handlers.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CreateJSInterface creates a JavaScript interface for the SSH client
|
||||
func CreateJSInterface(client *Client) js.Value {
|
||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
var bytes []byte
|
||||
|
||||
if data.Type() == js.TypeString {
|
||||
bytes = []byte(data.String())
|
||||
} else {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||
length := uint8Array.Get("length").Int()
|
||||
bytes = make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, uint8Array)
|
||||
}
|
||||
|
||||
_, err := client.Write(bytes)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
|
||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
cols := args[0].Int()
|
||||
rows := args[1].Int()
|
||||
err := client.Resize(cols, rows)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
|
||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
client.Close()
|
||||
return js.Undefined()
|
||||
}))
|
||||
|
||||
go readLoop(client, jsInterface)
|
||||
|
||||
return jsInterface
|
||||
}
|
||||
|
||||
func readLoop(client *Client, jsInterface js.Value) {
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := client.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
logrus.Debugf("SSH read error: %v", err)
|
||||
}
|
||||
if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() {
|
||||
onclose.Invoke()
|
||||
}
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(n)
|
||||
js.CopyBytesToJS(uint8Array, buffer[:n])
|
||||
ondata.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
}
|
||||
48
client/wasm/internal/ssh/key.go
Normal file
48
client/wasm/internal/ssh/key.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
|
||||
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
|
||||
keyStr := string(keyPEM)
|
||||
if !strings.Contains(keyStr, "-----BEGIN") {
|
||||
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyPEM)
|
||||
if err == nil {
|
||||
return signer, nil
|
||||
}
|
||||
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
|
||||
|
||||
block, _ := pem.Decode(keyPEM)
|
||||
if block == nil {
|
||||
keyPreview := string(keyPEM)
|
||||
if len(keyPreview) > 100 {
|
||||
keyPreview = keyPreview[:100]
|
||||
}
|
||||
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(rsaKey)
|
||||
}
|
||||
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(ecKey)
|
||||
}
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
@@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
return nil, fmt.Errorf("parsing url: %w", err)
|
||||
}
|
||||
var opts []grpc.DialOption
|
||||
if parsedURL.Scheme == "https" {
|
||||
tlsEnabled := parsedURL.Scheme == "https"
|
||||
if tlsEnabled {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
@@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
}
|
||||
|
||||
opts = append(opts,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
nbgrpc.WithCustomDialer(tlsEnabled),
|
||||
grpc.WithIdleTimeout(interval*2),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
|
||||
2
go.mod
2
go.mod
@@ -38,7 +38,7 @@ require (
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.12
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
|
||||
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
|
||||
@@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error {
|
||||
|
||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||
connRemoteAddr := remoteAddr(r)
|
||||
wsConn, err := websocket.Accept(w, r, nil)
|
||||
|
||||
acceptOptions := &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
}
|
||||
|
||||
wsConn, err := websocket.Accept(w, r, acceptOptions)
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
|
||||
return
|
||||
|
||||
@@ -223,10 +223,10 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
_, ok := c.conns[peerID]
|
||||
existingContainer, ok := c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrConnAlreadyExists
|
||||
return existingContainer.conn, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -235,7 +235,6 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
||||
msgChannel := make(chan Msg, 100)
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -249,11 +248,11 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
||||
c.muInstanceURL.Unlock()
|
||||
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||
|
||||
_, ok = c.conns[peerID]
|
||||
existingContainer, ok = c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
return nil, ErrConnAlreadyExists
|
||||
return existingContainer.conn, nil
|
||||
}
|
||||
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||
c.mu.Unlock()
|
||||
@@ -377,7 +376,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internal
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.log.Infof("start to Relay read loop exit")
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to read message from relay server: %s", errExit)
|
||||
@@ -468,12 +466,24 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
|
||||
container, ok := c.conns[*peerID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", peerID.String())
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
// Try to create a connection for this peer to handle incoming messages
|
||||
msgChannel := make(chan Msg, 100)
|
||||
c.muInstanceURL.Lock()
|
||||
instanceURL := c.instanceURL
|
||||
c.muInstanceURL.Unlock()
|
||||
conn := NewConn(c, *peerID, msgChannel, instanceURL)
|
||||
|
||||
c.mu.Lock()
|
||||
// Check again if connection was created while we were creating it
|
||||
if _, exists := c.conns[*peerID]; !exists {
|
||||
c.conns[*peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||
}
|
||||
container = c.conns[*peerID]
|
||||
c.mu.Unlock()
|
||||
}
|
||||
msg := Msg{
|
||||
bufPool: c.bufPool,
|
||||
|
||||
@@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||
return 0, err
|
||||
return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
|
||||
11
shared/relay/client/dialer/ws/dialopts_generic.go
Normal file
11
shared/relay/client/dialer/ws/dialopts_generic.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !js
|
||||
|
||||
package ws
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
|
||||
func createDialOptions() *websocket.DialOptions {
|
||||
return &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
}
|
||||
}
|
||||
10
shared/relay/client/dialer/ws/dialopts_js.go
Normal file
10
shared/relay/client/dialer/ws/dialopts_js.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build js
|
||||
|
||||
package ws
|
||||
|
||||
import "github.com/coder/websocket"
|
||||
|
||||
func createDialOptions() *websocket.DialOptions {
|
||||
// WASM version doesn't support HTTPClient
|
||||
return &websocket.DialOptions{}
|
||||
}
|
||||
@@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
}
|
||||
opts := createDialOptions()
|
||||
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -4,15 +4,8 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -21,35 +14,9 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
|
||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
||||
// Backoff returns a backoff configuration for gRPC calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxElapsedTime = 10 * time.Second
|
||||
@@ -57,6 +24,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options
|
||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
@@ -78,7 +46,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(),
|
||||
WithCustomDialer(tlsEnabled),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
|
||||
44
util/grpc/dialer_generic.go
Normal file
44
util/grpc/dialer_generic.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build !js
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
12
util/grpc/dialer_js.go
Normal file
12
util/grpc/dialer_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy/client"
|
||||
)
|
||||
|
||||
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
|
||||
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
|
||||
return client.WithWebSocketDialer(tlsEnabled)
|
||||
}
|
||||
8
util/util_js.go
Normal file
8
util/util_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package util
|
||||
|
||||
// IsAdmin returns false for WASM as there's no admin concept in browser
|
||||
func IsAdmin() bool {
|
||||
return false
|
||||
}
|
||||
171
util/wsproxy/client/dialer_js.go
Normal file
171
util/wsproxy/client/dialer_js.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const dialTimeout = 30 * time.Second
|
||||
|
||||
// websocketConn wraps a JavaScript WebSocket to implement net.Conn
|
||||
type websocketConn struct {
|
||||
ws js.Value
|
||||
remoteAddr string
|
||||
messages chan []byte
|
||||
readBuf []byte
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *websocketConn) Read(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
if len(c.readBuf) > 0 {
|
||||
n := copy(b, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
c.mu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
select {
|
||||
case data := <-c.messages:
|
||||
n := copy(b, data)
|
||||
if n < len(data) {
|
||||
c.mu.Lock()
|
||||
c.readBuf = data[n:]
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return n, nil
|
||||
case <-c.ctx.Done():
|
||||
return 0, c.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketConn) Write(b []byte) (int, error) {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return 0, c.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(b))
|
||||
js.CopyBytesToJS(uint8Array, b)
|
||||
c.ws.Call("send", uint8Array)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) Close() error {
|
||||
c.cancel()
|
||||
c.ws.Call("close")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) RemoteAddr() net.Addr {
|
||||
return stringAddr(c.remoteAddr)
|
||||
}
|
||||
func (c *websocketConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stringAddr is a simple net.Addr that returns a string
|
||||
type stringAddr string
|
||||
|
||||
func (s stringAddr) Network() string { return "tcp" }
|
||||
func (s stringAddr) String() string { return string(s) }
|
||||
|
||||
// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments.
|
||||
func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
scheme := "wss"
|
||||
if !tlsEnabled {
|
||||
scheme = "ws"
|
||||
}
|
||||
wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath)
|
||||
|
||||
ws := js.Global().Get("WebSocket").New(wsURL)
|
||||
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
conn := &websocketConn{
|
||||
ws: ws,
|
||||
remoteAddr: addr,
|
||||
messages: make(chan []byte, 100),
|
||||
ctx: connCtx,
|
||||
cancel: connCancel,
|
||||
}
|
||||
|
||||
ws.Set("binaryType", "arraybuffer")
|
||||
|
||||
openCh := make(chan struct{})
|
||||
errorCh := make(chan error, 1)
|
||||
|
||||
ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
close(openCh)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
select {
|
||||
case errorCh <- wsproxy.ErrConnectionFailed:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
event := args[0]
|
||||
data := event.Get("data")
|
||||
|
||||
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||
length := uint8Array.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, uint8Array)
|
||||
|
||||
select {
|
||||
case conn.messages <- bytes:
|
||||
default:
|
||||
log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr)
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case <-openCh:
|
||||
return conn, nil
|
||||
case err := <-errorCh:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
ws.Call("close")
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(dialTimeout):
|
||||
ws.Call("close")
|
||||
return nil, wsproxy.ErrConnectionTimeout
|
||||
}
|
||||
})
|
||||
}
|
||||
13
util/wsproxy/constants.go
Normal file
13
util/wsproxy/constants.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package wsproxy
|
||||
|
||||
import "errors"
|
||||
|
||||
// ProxyPath is the standard path where the WebSocket proxy is mounted on servers.
|
||||
const ProxyPath = "/ws-proxy"
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrConnectionTimeout = errors.New("WebSocket connection timeout")
|
||||
ErrConnectionFailed = errors.New("WebSocket connection failed")
|
||||
ErrBackendUnavailable = errors.New("backend unavailable")
|
||||
)
|
||||
Reference in New Issue
Block a user