Add wasm client

This commit is contained in:
Viktor Liu
2025-08-16 16:44:16 +02:00
parent dbefa8bd9f
commit 6d99d451d6
55 changed files with 2525 additions and 87 deletions

0
.gitmodules vendored Normal file
View File

View File

@@ -2,6 +2,18 @@ version: 2
project_name: netbird
builds:
- id: netbird-wasm
dir: client/wasm/cmd
binary: netbird.wasm
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
goos:
- js
goarch:
- wasm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird
dir: client
binary: netbird
@@ -115,6 +127,13 @@ archives:
- builds:
- netbird
- netbird-static
- id: netbird-wasm
builds:
- netbird-wasm
name_template: "{{ .ProjectName }}_wasm_{{ .Version }}"
format: tar.gz
files:
- none*
nfpms:
- maintainer: Netbird <dev@netbird.io>

8
client/cmd/debug_js.go Normal file
View File

@@ -0,0 +1,8 @@
package cmd
import "context"
// SetupDebugHandler is a no-op for WASM
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
// Debug handler not needed for WASM
}

View File

@@ -23,23 +23,27 @@ import (
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
// Client manages a netbird embedded client instance
// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
config *profilemanager.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
jwtToken string
connect *internal.ConnectClient
}
// Options configures a new Client
// Options configures a new Client.
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// JWTToken is used for JWT-based authentication
JWTToken string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
@@ -58,8 +62,15 @@ type Options struct {
DisableClientRoutes bool
}
// New creates a new netbird embedded client
// New creates a new netbird embedded client.
func New(opts Options) (*Client, error) {
if opts.SetupKey == "" && opts.JWTToken == "" {
return nil, fmt.Errorf("either SetupKey or JWTToken must be provided")
}
if opts.SetupKey != "" && opts.JWTToken != "" {
return nil, fmt.Errorf("cannot specify both SetupKey and JWTToken")
}
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
@@ -110,6 +121,7 @@ func New(opts Options) (*Client, error) {
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
}, nil
}
@@ -126,7 +138,7 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
@@ -187,6 +199,16 @@ func (c *Client) Stop(ctx context.Context) error {
}
}
// GetConfig returns a copy of the internal client config.
func (c *Client) GetConfig() (profilemanager.Config, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.config == nil {
return profilemanager.Config{}, ErrConfigNotInitialized
}
return *c.config, nil
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -211,7 +233,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
@@ -232,7 +254,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// ListenUDP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()

View File

@@ -1,3 +1,5 @@
//go:build !js
package bind
import (

View File

@@ -0,0 +1,69 @@
//go:build js
package bind
import (
"net"
"net/netip"
"sync"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// RecvMessage represents a received message
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
// ICEBind is a bind implementation that uses ICE candidates for connectivity
type ICEBind struct {
address wgaddr.Address
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
udpMux *UniversalUDPMuxDefault
muUDPMux sync.Mutex
transportNet transport.Net
receiverCreated bool
activityRecorder *ActivityRecorder
RecvChan chan RecvMessage
closed bool // Flag to signal that bind is closed
closedMu sync.Mutex
mtu uint16
}
// NewICEBind creates a new ICEBind instance
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
return &ICEBind{
address: address,
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
RecvChan: make(chan RecvMessage, 100),
activityRecorder: NewActivityRecorder(),
mtu: mtu,
}
}
// SetFilter updates the filter function
func (s *ICEBind) SetFilter(filter FilterFn) {
s.filterFn = filter
}
// GetAddress returns the bind address
func (s *ICEBind) GetAddress() wgaddr.Address {
return s.address
}
// ActivityRecorder returns the activity recorder
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// MTU returns the maximum transmission unit
func (s *ICEBind) MTU() uint16 {
return s.mtu
}

View File

@@ -0,0 +1,141 @@
//go:build js
package bind
import (
"net"
"net/netip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
)
// GetICEMux returns a dummy UDP mux for WASM since browsers don't support UDP.
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return nil, nil
}
// Open creates a receive function for handling relay packets in WASM.
func (s *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
log.Debugf("Open: creating receive function for port %d", uport)
s.closedMu.Lock()
s.closed = false
s.closedMu.Unlock()
if !s.receiverCreated {
s.receiverCreated = true
log.Debugf("Open: first call, setting receiverCreated=true")
}
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
s.closedMu.Lock()
if s.closed {
s.closedMu.Unlock()
return 0, net.ErrClosed
}
s.closedMu.Unlock()
msg, ok := <-s.RecvChan
if !ok {
return 0, net.ErrClosed
}
copy(bufs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = conn.Endpoint(msg.Endpoint)
return 1, nil
}
log.Debugf("Open: receive function created, returning port %d", uport)
return []conn.ReceiveFunc{receiveFn}, uport, nil
}
// SetMark is not applicable in WASM/browser environment.
func (s *ICEBind) SetMark(_ uint32) error {
return nil
}
// Send forwards packets through the relay connection for WASM.
func (s *ICEBind) Send(bufs [][]byte, ep conn.Endpoint) error {
if ep == nil {
return nil
}
fakeIP := ep.DstIP()
s.endpointsMu.Lock()
relayConn, ok := s.endpoints[fakeIP]
s.endpointsMu.Unlock()
if !ok {
return nil
}
for _, buf := range bufs {
if _, err := relayConn.Write(buf); err != nil {
log.Errorf("Send: failed to write to relay: %v", err)
return err
}
}
return nil
}
// SetEndpoint stores a relay endpoint for a fake IP.
func (s *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
if oldConn, exists := s.endpoints[fakeIP]; exists {
if oldConn != conn {
log.Debugf("SetEndpoint: replacing existing connection for %s", fakeIP)
if err := oldConn.Close(); err != nil {
log.Debugf("SetEndpoint: error closing old connection: %v", err)
}
s.endpoints[fakeIP] = conn
} else {
log.Tracef("SetEndpoint: same connection already set for %s, skipping", fakeIP)
}
} else {
log.Debugf("SetEndpoint: setting new relay connection for fake IP %s", fakeIP)
s.endpoints[fakeIP] = conn
}
}
// RemoveEndpoint removes a relay endpoint.
func (s *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
s.endpointsMu.Lock()
defer s.endpointsMu.Unlock()
delete(s.endpoints, fakeIP)
}
// BatchSize returns the batch size for WASM.
func (s *ICEBind) BatchSize() int {
return 1
}
// ParseEndpoint parses an endpoint string.
func (s *ICEBind) ParseEndpoint(s2 string) (conn.Endpoint, error) {
addrPort, err := netip.ParseAddrPort(s2)
if err != nil {
log.Errorf("ParseEndpoint: failed to parse %s: %v", s2, err)
return nil, err
}
ep := &Endpoint{AddrPort: addrPort}
return ep, nil
}
// Close closes the ICEBind.
func (s *ICEBind) Close() error {
log.Debugf("Close: closing ICEBind (receiverCreated=%v)", s.receiverCreated)
s.closedMu.Lock()
s.closed = true
s.closedMu.Unlock()
s.receiverCreated = false
log.Debugf("Close: returning from Close")
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build linux || windows || freebsd
//go:build linux || windows || freebsd || js || wasip1
package configurer

View File

@@ -1,4 +1,4 @@
//go:build !windows
//go:build !windows && !js
package configurer

View File

@@ -0,0 +1,23 @@
package configurer
import (
"net"
)
type noopListener struct{}
func (n *noopListener) Accept() (net.Conn, error) {
return nil, net.ErrClosed
}
func (n *noopListener) Close() error {
return nil
}
func (n *noopListener) Addr() net.Addr {
return nil
}
func openUAPI(deviceName string) (net.Listener, error) {
return &noopListener{}, nil
}

View File

@@ -0,0 +1,6 @@
package iface
// Destroy is a no-op on WASM
func (w *WGIface) Destroy() error {
return nil
}

View File

@@ -0,0 +1,27 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
wgIface := &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIface, nil
}

View File

@@ -1,3 +1,5 @@
//go:build !js
package netstack
import (

View File

@@ -0,0 +1,12 @@
package netstack
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
// IsEnabled always returns true for js since it's the only mode available
func IsEnabled() bool {
return true
}
func ListenAddr() string {
return ""
}

View File

@@ -0,0 +1,5 @@
package dns
func (s *DefaultServer) initialize() (hostManager, error) {
return &noopHostConfigurator{}, nil
}

View File

@@ -0,0 +1,19 @@
package dns
import (
"context"
)
type ShutdownState struct{}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
return nil
}

View File

@@ -450,14 +450,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("initialize dns server: %w", err)
}
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
iceCfg := e.createICEConfig()
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr.Start(e.ctx)
@@ -1288,14 +1281,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
Addr: e.getRosenpassAddr(),
PermissiveMode: e.config.RosenpassPermissive,
},
ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},
ICEConfig: e.createICEConfig(),
}
serviceDependencies := peer.ServiceDependencies{

View File

@@ -0,0 +1,19 @@
//go:build !js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for non-WASM environments
func (e *Engine) createICEConfig() icemaker.Config {
return icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
}

View File

@@ -0,0 +1,24 @@
//go:build js
package internal
import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
)
// createICEConfig creates ICE configuration for WASM environment.
func (e *Engine) createICEConfig() icemaker.Config {
cfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
if e.udpMux != nil {
cfg.UDPMux = e.udpMux.UDPMuxDefault
cfg.UDPMuxSrflx = e.udpMux
}
return cfg
}

View File

@@ -0,0 +1,12 @@
package networkmonitor
import (
"context"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
// No-op for WASM - network changes don't apply
return nil
}

View File

@@ -455,6 +455,14 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
// Check if we already have a relay proxy configured
if conn.wgProxyRelay != nil {
conn.Log.Debugf("Relay proxy already configured, skipping duplicate setup")
// Update status to ensure it's connected
conn.statusRelay.SetConnected()
return
}
conn.dumpState.RelayConnected()
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
@@ -472,19 +480,42 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
if conn.isICEActive() {
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy)
// For WASM, we still need to start the proxy and configure WireGuard
// because ICE doesn't actually work in browsers and we rely on relay
conn.Log.Infof("WASM check: runtime.GOOS=%s, should start proxy=%v", runtime.GOOS, runtime.GOOS == "js")
if runtime.GOOS == "js" {
conn.Log.Infof("WASM: starting relay proxy and configuring WireGuard despite ICE being 'active'")
wgProxy.Work()
// Configure WireGuard to use the relay proxy endpoint
endpointAddr := wgProxy.EndpointAddr()
conn.Log.Infof("WASM: Configuring WireGuard endpoint to proxy address: %v", endpointAddr)
if err := conn.configureWGEndpoint(endpointAddr, rci.rosenpassPubKey); err != nil {
conn.Log.Errorf("WASM: Failed to update WireGuard peer configuration: %v", err)
} else {
conn.Log.Infof("WASM: Successfully configured WireGuard endpoint to use proxy at %v", endpointAddr)
}
// Update connection priority to relay for WASM
conn.currentConnPriority = conntype.Relay
conn.rosenpassRemoteKey = rci.rosenpassPubKey
}
conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
}
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
endpointAddr := wgProxy.EndpointAddr()
conn.Log.Infof("Configuring WireGuard endpoint to proxy address: %v", endpointAddr)
if err := conn.configureWGEndpoint(endpointAddr, rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
}
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return
}
conn.Log.Infof("Successfully configured WireGuard endpoint to use proxy at %v", endpointAddr)
conn.wgWatcherWg.Add(1)
go func() {
@@ -663,6 +694,21 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
// In WASM with forced relay, ICE is not used, so skip ICE check
if runtime.GOOS == "js" && os.Getenv("NB_FORCE_RELAY") == "true" {
// Only check relay connection status
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
relayConnected := conn.statusRelay.Get() != worker.StatusDisconnected
if !relayConnected {
conn.Log.Tracef("WASM: relay not connected for connectivity check")
}
return relayConnected
}
// If relay is not supported, consider it connected to avoid reconnect loop
conn.Log.Tracef("WASM: relay not supported, returning true to avoid reconnect")
return true
}
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}

View File

@@ -2,7 +2,6 @@ package peer
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
@@ -55,6 +54,22 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.relaySupportedOnRemotePeer.Store(true)
// Check if we already have an active relay connection
w.relayLock.Lock()
existingConn := w.relayedConn
w.relayLock.Unlock()
if existingConn != nil {
w.log.Debugf("relay connection already exists for peer %s, reusing it", w.config.Key)
// Connection exists, just ensure proxy is set up if needed
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
relayedConn: existingConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
})
return
}
// the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
if err != nil {
@@ -66,15 +81,24 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection")
return
}
// The relay manager never actually returns ErrConnAlreadyExists - it returns
// the existing connection with nil error. This error handling is for other failures.
w.log.Errorf("failed to open connection via Relay: %s", err)
return
}
w.relayLock.Lock()
// Check if we already stored this connection (might happen if OpenConn returned existing)
if w.relayedConn != nil && w.relayedConn == relayedConn {
w.relayLock.Unlock()
w.log.Debugf("OpenConn returned the same connection we already have for peer %s", w.config.Key)
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
relayedConn: relayedConn,
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
})
return
}
w.relayedConn = relayedConn
w.relayLock.Unlock()
@@ -123,11 +147,16 @@ func (w *WorkerRelay) CloseConn() {
if err := w.relayedConn.Close(); err != nil {
w.log.Warnf("failed to close relay connection: %v", err)
}
// Clear the stored connection to allow reopening
w.relayedConn = nil
}
func (w *WorkerRelay) onWGDisconnected() {
w.relayLock.Lock()
_ = w.relayedConn.Close()
if w.relayedConn != nil {
_ = w.relayedConn.Close()
w.relayedConn = nil
}
w.relayLock.Unlock()
w.conn.onRelayDisconnected()
@@ -148,6 +177,11 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
}
func (w *WorkerRelay) onRelayClientDisconnected() {
// Clear the stored connection when relay disconnects
w.relayLock.Lock()
w.relayedConn = nil
w.relayLock.Unlock()
w.wgWatcher.DisableWgWatcher()
go w.conn.onRelayDisconnected()
}

View File

@@ -0,0 +1,48 @@
package systemops
import (
"errors"
"net"
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
var ErrRouteNotSupported = errors.New("route operations not supported on js")
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return ErrRouteNotSupported
}
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return ErrRouteNotSupported
}
func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil
}
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return ErrRouteNotSupported
}
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
return ErrRouteNotSupported
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return nil
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !ios
//go:build !linux && !ios && !js
package systemops

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import "context"

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

137
client/ssh/ssh_js.go Normal file
View File

@@ -0,0 +1,137 @@
package ssh
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"encoding/pem"
"errors"
"strings"
"golang.org/x/crypto/ssh"
)
var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment")
// Server is a dummy SSH server interface for WASM.
type Server interface {
Start() error
Stop() error
EnableSSH(enabled bool)
AddAuthorizedKey(peer string, key string) error
RemoveAuthorizedKey(key string)
}
type dummyServer struct{}
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return &dummyServer{}, nil
}
func NewServer(addr string) Server {
return &dummyServer{}
}
func (s *dummyServer) Start() error {
return ErrSSHNotSupported
}
func (s *dummyServer) Stop() error {
return nil
}
func (s *dummyServer) EnableSSH(enabled bool) {
}
func (s *dummyServer) AddAuthorizedKey(peer string, key string) error {
return nil
}
func (s *dummyServer) RemoveAuthorizedKey(key string) {
}
type Client struct{}
func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) {
return nil, ErrSSHNotSupported
}
func (c *Client) Close() error {
return nil
}
func (c *Client) Run(command []string) error {
return ErrSSHNotSupported
}
type SessionRecorder struct{}
func NewSessionRecorder() *SessionRecorder {
return &SessionRecorder{}
}
func (r *SessionRecorder) Record(session string, data []byte) {
}
func GetUserShell() string {
return "/bin/sh"
}
func LookupUserInfo(username string) (string, string, error) {
return "", "", ErrSSHNotSupported
}
const DefaultSSHPort = 44338
const ED25519 = "ed25519"
func isRoot() bool {
return false
}
func GeneratePrivateKey(keyType string) ([]byte, error) {
if keyType != ED25519 {
return nil, errors.New("only ED25519 keys are supported in WASM")
}
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, err
}
pemBlock := &pem.Block{
Type: "PRIVATE KEY",
Bytes: pkcs8Bytes,
}
pemBytes := pem.EncodeToMemory(pemBlock)
return pemBytes, nil
}
func GeneratePublicKey(privateKey []byte) ([]byte, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
block, _ := pem.Decode(privateKey)
if block != nil {
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, err
}
signer, err = ssh.NewSignerFromKey(key)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey())
return []byte(strings.TrimSpace(string(pubKeyBytes))), nil
}

View File

@@ -1,3 +1,5 @@
//go:build !js
package ssh
import (

234
client/system/info_js.go Normal file
View File

@@ -0,0 +1,234 @@
package system
import (
"context"
"runtime"
"strings"
"syscall/js"
"github.com/netbirdio/netbird/version"
)
// GetInfo retrieves system information for WASM environment
func GetInfo(_ context.Context) *Info {
info := &Info{
GoOS: runtime.GOOS,
Kernel: runtime.GOARCH,
KernelVersion: runtime.GOARCH,
Platform: runtime.GOARCH,
OS: runtime.GOARCH,
Hostname: "wasm-client",
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
}
collectBrowserInfo(info)
collectLocationInfo(info)
si := updateStaticInfo()
info.SystemSerialNumber = si.SystemSerialNumber
info.SystemProductName = si.SystemProductName
info.SystemManufacturer = si.SystemManufacturer
info.Environment = si.Environment
return info
}
func collectBrowserInfo(info *Info) {
navigator := js.Global().Get("navigator")
if navigator.IsUndefined() {
return
}
collectUserAgent(info, navigator)
collectPlatform(info, navigator)
collectCPUInfo(info, navigator)
}
func collectUserAgent(info *Info, navigator js.Value) {
ua := navigator.Get("userAgent")
if ua.IsUndefined() {
return
}
userAgent := ua.String()
os, osVersion := parseOSFromUserAgent(userAgent)
if os != "" {
info.OS = os
}
if osVersion != "" {
info.OSVersion = osVersion
}
}
func collectPlatform(info *Info, navigator js.Value) {
// Try regular platform property
if plat := navigator.Get("platform"); !plat.IsUndefined() {
if platStr := plat.String(); platStr != "" {
info.Platform = platStr
}
}
// Try newer userAgentData API for more accurate platform
userAgentData := navigator.Get("userAgentData")
if userAgentData.IsUndefined() {
return
}
platformInfo := userAgentData.Get("platform")
if !platformInfo.IsUndefined() {
if platStr := platformInfo.String(); platStr != "" {
info.Platform = platStr
}
}
}
func collectCPUInfo(info *Info, navigator js.Value) {
hardwareConcurrency := navigator.Get("hardwareConcurrency")
if !hardwareConcurrency.IsUndefined() {
info.CPUs = hardwareConcurrency.Int()
}
}
func collectLocationInfo(info *Info) {
location := js.Global().Get("location")
if location.IsUndefined() {
return
}
if host := location.Get("hostname"); !host.IsUndefined() {
hostnameStr := host.String()
if hostnameStr != "" && hostnameStr != "localhost" {
info.Hostname = hostnameStr
}
}
}
func checkFileAndProcess(_ []string) ([]File, error) {
return []File{}, nil
}
func updateStaticInfo() *StaticInfo {
si := &StaticInfo{}
navigator := js.Global().Get("navigator")
if !navigator.IsUndefined() {
if vendor := navigator.Get("vendor"); !vendor.IsUndefined() {
si.SystemManufacturer = vendor.String()
}
if product := navigator.Get("product"); !product.IsUndefined() {
si.SystemProductName = product.String()
}
if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() {
ua := userAgent.String()
si.Environment = detectEnvironmentFromUA(ua)
}
}
return si
}
func parseOSFromUserAgent(userAgent string) (string, string) {
if userAgent == "" {
return "", ""
}
switch {
case strings.Contains(userAgent, "Windows NT"):
return parseWindowsVersion(userAgent)
case strings.Contains(userAgent, "Mac OS X"):
return parseMacOSVersion(userAgent)
case strings.Contains(userAgent, "FreeBSD"):
return "FreeBSD", ""
case strings.Contains(userAgent, "OpenBSD"):
return "OpenBSD", ""
case strings.Contains(userAgent, "NetBSD"):
return "NetBSD", ""
case strings.Contains(userAgent, "Linux"):
return parseLinuxVersion(userAgent)
case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"):
return parseiOSVersion(userAgent)
case strings.Contains(userAgent, "CrOS"):
return "ChromeOS", ""
default:
return "", ""
}
}
func parseWindowsVersion(userAgent string) (string, string) {
switch {
case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"):
return "Windows", "10/11"
case strings.Contains(userAgent, "Windows NT 10.0"):
return "Windows", "10"
case strings.Contains(userAgent, "Windows NT 6.3"):
return "Windows", "8.1"
case strings.Contains(userAgent, "Windows NT 6.2"):
return "Windows", "8"
case strings.Contains(userAgent, "Windows NT 6.1"):
return "Windows", "7"
default:
return "Windows", "Unknown"
}
}
func parseMacOSVersion(userAgent string) (string, string) {
idx := strings.Index(userAgent, "Mac OS X ")
if idx == -1 {
return "macOS", "Unknown"
}
versionStart := idx + len("Mac OS X ")
versionEnd := strings.Index(userAgent[versionStart:], ")")
if versionEnd <= 0 {
return "macOS", "Unknown"
}
ver := userAgent[versionStart : versionStart+versionEnd]
ver = strings.ReplaceAll(ver, "_", ".")
return "macOS", ver
}
func parseLinuxVersion(userAgent string) (string, string) {
if strings.Contains(userAgent, "Android") {
return "Android", extractAndroidVersion(userAgent)
}
if strings.Contains(userAgent, "Ubuntu") {
return "Ubuntu", ""
}
return "Linux", ""
}
func parseiOSVersion(userAgent string) (string, string) {
idx := strings.Index(userAgent, "OS ")
if idx == -1 {
return "iOS", "Unknown"
}
versionStart := idx + 3
versionEnd := strings.Index(userAgent[versionStart:], " ")
if versionEnd <= 0 {
return "iOS", "Unknown"
}
ver := userAgent[versionStart : versionStart+versionEnd]
ver = strings.ReplaceAll(ver, "_", ".")
return "iOS", ver
}
func extractAndroidVersion(userAgent string) string {
if idx := strings.Index(userAgent, "Android "); idx != -1 {
versionStart := idx + len("Android ")
versionEnd := strings.IndexAny(userAgent[versionStart:], ";)")
if versionEnd > 0 {
return userAgent[versionStart : versionStart+versionEnd]
}
}
return "Unknown"
}
func detectEnvironmentFromUA(_ string) Environment {
return Environment{}
}

236
client/wasm/cmd/main.go Normal file
View File

@@ -0,0 +1,236 @@
package main
import (
"context"
"fmt"
"log"
"syscall/js"
"time"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
)
func main() {
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
select {}
}
func startClient(ctx context.Context, nbClient *netbird.Client) error {
log.Println("Starting NetBird client...")
if err := nbClient.Start(ctx); err != nil {
return err
}
log.Println("NetBird client started successfully")
return nil
}
// parseClientOptions extracts NetBird options from JavaScript object
func parseClientOptions(jsOptions js.Value) (netbird.Options, error) {
options := netbird.Options{
DeviceName: "dashboard-client",
LogLevel: "warn",
}
if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() {
options.JWTToken = jwtToken.String()
}
if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() {
options.SetupKey = setupKey.String()
}
if options.JWTToken == "" && options.SetupKey == "" {
return options, fmt.Errorf("either jwtToken or setupKey must be provided")
}
if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() {
mgmtURLStr := mgmtURL.String()
if mgmtURLStr != "" {
options.ManagementURL = mgmtURLStr
}
}
if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() {
options.LogLevel = logLevel.String()
}
if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() {
options.DeviceName = deviceName.String()
}
return options, nil
}
// createStartMethod creates the start method for the client
func createStartMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout)
defer cancel()
if err := startClient(ctx, client); err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
resolve.Invoke(js.ValueOf(true))
})
})
}
// createStopMethod creates the stop method for the client
func createStopMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
defer cancel()
if err := client.Stop(ctx); err != nil {
log.Printf("Error stopping client: %v", err)
reject.Invoke(js.ValueOf(err.Error()))
return
}
log.Println("NetBird client stopped")
resolve.Invoke(js.ValueOf(true))
})
})
}
// createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}
host := args[0].String()
port := args[1].Int()
username := "root"
if len(args) > 2 && args[2].String() != "" {
username = args[2].String()
}
return createPromise(func(resolve, reject js.Value) {
sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username); err != nil {
reject.Invoke(err.Error())
return
}
if err := sshClient.StartSession(80, 24); err != nil {
if closeErr := sshClient.Close(); closeErr != nil {
log.Printf("Error closing SSH client: %v", closeErr)
}
reject.Invoke(err.Error())
return
}
jsInterface := ssh.CreateJSInterface(sshClient)
resolve.Invoke(jsInterface)
})
})
}
// createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: request details required")
}
request := args[0]
return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request)
if err != nil {
reject.Invoke(err.Error())
return
}
resolve.Invoke(response)
})
})
}
// createRDPProxyMethod creates the RDP proxy method
func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}
// createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]
go handler(resolve, reject)
return nil
}))
}
// createClientObject wraps the NetBird client in a JavaScript object
func createClientObject(client *netbird.Client) js.Value {
obj := make(map[string]interface{})
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
return js.ValueOf(obj)
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(this js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]
if len(args) < 1 {
reject.Invoke(js.ValueOf("Options object required"))
return nil
}
go func() {
options, err := parseClientOptions(args[0])
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
log.Printf("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s",
options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL)
client, err := netbird.New(options)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err)))
return
}
clientObj := createClientObject(client)
log.Println("NetBird client created successfully")
resolve.Invoke(clientObj)
}()
return nil
}))
}

View File

@@ -0,0 +1,98 @@
package http
import (
"fmt"
"io"
"log"
"net/http"
"strings"
"syscall/js"
"time"
netbird "github.com/netbirdio/netbird/client/embed"
)
const (
httpTimeout = 30 * time.Second
maxResponseSize = 1024 * 1024 // 1MB
)
// performRequest executes an HTTP request through NetBird and returns the response and body
func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) {
httpClient := nbClient.NewHTTPClient()
httpClient.Timeout = httpTimeout
req, err := http.NewRequest(method, url, strings.NewReader(string(body)))
if err != nil {
return nil, nil, fmt.Errorf("create request: %w", err)
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, nil, fmt.Errorf("request failed: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Printf("failed to close response body: %v", err)
}
}()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
if err != nil {
return nil, nil, fmt.Errorf("read response: %w", err)
}
return resp, respBody, nil
}
// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object
func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) {
url := request.Get("url").String()
if url == "" {
return js.Undefined(), fmt.Errorf("URL is required")
}
method := "GET"
if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() {
method = strings.ToUpper(methodVal.String())
}
var requestBody []byte
if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() {
requestBody = []byte(bodyVal.String())
}
requestHeaders := make(map[string]string)
if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject {
headerKeys := js.Global().Get("Object").Call("keys", headersVal)
for i := 0; i < headerKeys.Length(); i++ {
key := headerKeys.Index(i).String()
value := headersVal.Get(key).String()
requestHeaders[key] = value
}
}
resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody)
if err != nil {
return js.Undefined(), err
}
result := js.Global().Get("Object").New()
result.Set("status", resp.StatusCode)
result.Set("statusText", resp.Status)
result.Set("body", string(body))
headers := js.Global().Get("Object").New()
for key, values := range resp.Header {
if len(values) > 0 {
headers.Set(strings.ToLower(key), values[0])
}
}
result.Set("headers", headers)
return result, nil
}

View File

@@ -0,0 +1,94 @@
package rdp
import (
"crypto/tls"
"crypto/x509"
"fmt"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
certValidationTimeout = 60 * time.Second
)
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
return false, fmt.Errorf("certificate validation handler not configured")
}
certInfo := js.Global().Get("Object").New()
certInfo.Set("ServerAddr", conn.destination)
certArray := js.Global().Get("Array").New()
for i, certBytes := range certChain {
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
js.CopyBytesToJS(uint8Array, certBytes)
certArray.SetIndex(i, uint8Array)
}
certInfo.Set("ServerCertChain", certArray)
if len(certChain) > 0 {
cert, err := x509.ParseCertificate(certChain[0])
if err == nil {
info := js.Global().Get("Object").New()
info.Set("subject", cert.Subject.String())
info.Set("issuer", cert.Issuer.String())
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
info.Set("serialNumber", cert.SerialNumber.String())
certInfo.Set("CertificateInfo", info)
}
}
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
resultChan := make(chan bool)
errorChan := make(chan error)
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
result := args[0].Bool()
resultChan <- result
return nil
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
errorChan <- fmt.Errorf("certificate validation failed")
return nil
}))
select {
case result := <-resultChan:
if result {
log.Info("Certificate accepted by user")
} else {
log.Info("Certificate rejected by user")
}
return result, nil
case err := <-errorChan:
return false, err
case <-time.After(certValidationTimeout):
return false, fmt.Errorf("certificate validation timeout")
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
for _, cert := range cs.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
accepted, err := p.validateCertificateWithJS(conn, certChain)
if err != nil {
return err
}
if !accepted {
return fmt.Errorf("certificate rejected by user")
}
return nil
},
}
}

View File

@@ -0,0 +1,269 @@
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"fmt"
"io"
"net"
"sync"
"syscall/js"
log "github.com/sirupsen/logrus"
)
const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
activeConnections map[string]*proxyConnection
destinations map[string]string
mu sync.Mutex
}
type proxyConnection struct {
id string
destination string
rdpConn net.Conn
tlsConn *tls.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
}
// NewRDCleanPathProxy creates a new RDCleanPath proxy
func NewRDCleanPathProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *RDCleanPathProxy {
return &RDCleanPathProxy{
nbClient: client,
activeConnections: make(map[string]*proxyConnection),
}
}
// CreateProxy creates a new proxy endpoint for the given destination
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := fmt.Sprintf("%s:%s", hostname, port)
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
p.mu.Lock()
if p.destinations == nil {
p.destinations = make(map[string]string)
}
p.destinations[proxyID] = destination
p.mu.Unlock()
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
ws := args[0]
p.HandleWebSocketConnection(ws, proxyID)
return nil
}))
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL)
}()
return nil
}))
}
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
p.mu.Lock()
destination := p.destinations[proxyID]
p.mu.Unlock()
if destination == "" {
log.Errorf("No destination found for proxy ID: %s", proxyID)
return
}
ctx, cancel := context.WithCancel(context.Background())
// Don't defer cancel here - it will be called by cleanupConnection
conn := &proxyConnection{
id: proxyID,
destination: destination,
wsHandlers: ws,
ctx: ctx,
cancel: cancel,
}
p.mu.Lock()
p.activeConnections[proxyID] = conn
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
}
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
}))
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript")
conn.cancel()
return nil
}))
}
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
return
}
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
if conn.rdpConn != nil || conn.tlsConn != nil {
p.forwardToRDP(conn, bytes)
return
}
var pdu RDCleanPathPDU
_, err := asn1.Unmarshal(bytes, &pdu)
if err != nil {
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
n := len(bytes)
if n > 20 {
n = 20
}
log.Warnf("First %d bytes: %x", n, bytes[:n])
if len(bytes) > 0 && bytes[0] == 0x03 {
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
go p.handleDirectRDP(conn, bytes)
return
}
return
}
go p.processRDCleanPathPDU(conn, pdu)
}
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
var writer io.Writer
var connType string
if conn.tlsConn != nil {
writer = conn.tlsConn
connType = "TLS"
} else if conn.rdpConn != nil {
writer = conn.rdpConn
connType = "TCP"
} else {
log.Error("No RDP connection available")
return
}
if _, err := writer.Write(bytes); err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
}
}
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
defer p.cleanupConnection(conn)
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
return
}
conn.rdpConn = rdpConn
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
return
}
response := make([]byte, 1024)
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
return
}
p.sendToWebSocket(conn, response[:n])
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
}
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
conn.rdpConn = nil
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
p.mu.Unlock()
}
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
} else if conn.wsHandlers.Get("send").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}

View File

@@ -0,0 +1,249 @@
package rdp
import (
"crypto/tls"
"encoding/asn1"
"io"
"syscall/js"
log "github.com/sirupsen/logrus"
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version")
return
}
destination := conn.destination
if pdu.Destination != "" {
destination = pdu.Destination
}
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed")
p.cleanupConnection(conn)
return
}
conn.rdpConn = rdpConn
// RDP always starts with X.224 negotiation, then determines if TLS is needed
// Modern RDP (since Windows Vista/2008) typically requires TLS
// The X.224 Connection Confirm response will indicate if TLS is required
// For now, we'll attempt TLS for all connections as it's the modern default
p.setupTLSConnection(conn, pdu)
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
tlsConfig := p.getTLSConfigWithValidation(conn)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed")
return
}
log.Info("TLS handshake successful")
// Certificate validation happens during handshake via VerifyConnection callback
var certChain [][]byte
connState := tlsConn.ConnectionState()
if len(connState.PeerCertificates) > 0 {
for _, cert := range connState.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
ServerCertChain: certChain,
}
if len(x224Response) > 0 {
responsePDU.X224ConnectionPDU = x224Response
}
p.sendRDCleanPathPDU(conn, responsePDU)
log.Debug("Starting TLS forwarding")
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
<-conn.ctx.Done()
log.Debug("TLS connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
return
}
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
if len(args) < 1 {
errChan <- io.EOF
return nil
}
data := args[0]
if data.InstanceOf(js.Global().Get("Uint8Array")) {
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
msgChan <- bytes
}
return nil
})
defer handler.Release()
conn.wsHandlers.Set("onceGoMessage", handler)
select {
case msg := <-msgChan:
return msg, nil
case err := <-errChan:
return nil, err
case <-conn.ctx.Done():
return nil, conn.ctx.Err()
}
}
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
for {
if conn.ctx.Err() != nil {
return
}
msg, err := p.readWebSocketMessage(conn)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from WebSocket: %v", err)
}
return
}
_, err = dst.Write(msg)
if err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
return
}
}
}
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
buffer := make([]byte, 32*1024)
for {
if conn.ctx.Err() != nil {
return
}
n, err := src.Read(buffer)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from %s: %v", connType, err)
}
return
}
if n > 0 {
p.sendToWebSocket(conn, buffer[:n])
}
}
}

View File

@@ -0,0 +1,211 @@
package ssh
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
netbird "github.com/netbirdio/netbird/client/embed"
)
const (
sshDialTimeout = 30 * time.Second
)
func closeWithLog(c io.Closer, resource string) {
if c != nil {
if err := c.Close(); err != nil {
logrus.Debugf("Failed to close %s: %v", resource, err)
}
}
}
type Client struct {
nbClient *netbird.Client
sshClient *ssh.Client
session *ssh.Session
stdin io.WriteCloser
stdout io.Reader
stderr io.Reader
mu sync.RWMutex
}
// NewClient creates a new SSH client
func NewClient(nbClient *netbird.Client) *Client {
return &Client{
nbClient: nbClient,
}
}
// Connect establishes an SSH connection through NetBird network
func (c *Client) Connect(host string, port int, username string) error {
addr := fmt.Sprintf("%s:%d", host, port)
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
var authMethods []ssh.AuthMethod
nbConfig, err := c.nbClient.GetConfig()
if err != nil {
return fmt.Errorf("get NetBird config: %w", err)
}
if nbConfig.SSHKey == "" {
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
}
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
if err != nil {
return fmt.Errorf("parse NetBird SSH private key: %w", err)
}
pubKey := signer.PublicKey()
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
authMethods = append(authMethods, ssh.PublicKeys(signer))
config := &ssh.ClientConfig{
User: username,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: sshDialTimeout,
}
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
defer cancel()
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("dial %s: %w", addr, err)
}
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
closeWithLog(conn, "connection after handshake error")
return fmt.Errorf("SSH handshake: %w", err)
}
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
logrus.Infof("SSH: Connected to %s", addr)
return nil
}
// StartSession starts an SSH session with PTY
func (c *Client) StartSession(cols, rows int) error {
if c.sshClient == nil {
return fmt.Errorf("SSH client not connected")
}
session, err := c.sshClient.NewSession()
if err != nil {
return fmt.Errorf("create session: %w", err)
}
c.mu.Lock()
defer c.mu.Unlock()
c.session = session
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
ssh.VINTR: 3,
ssh.VQUIT: 28,
ssh.VERASE: 127,
}
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
closeWithLog(session, "session after PTY error")
return fmt.Errorf("PTY request: %w", err)
}
c.stdin, err = session.StdinPipe()
if err != nil {
closeWithLog(session, "session after stdin error")
return fmt.Errorf("get stdin: %w", err)
}
c.stdout, err = session.StdoutPipe()
if err != nil {
closeWithLog(session, "session after stdout error")
return fmt.Errorf("get stdout: %w", err)
}
c.stderr, err = session.StderrPipe()
if err != nil {
closeWithLog(session, "session after stderr error")
return fmt.Errorf("get stderr: %w", err)
}
if err := session.Shell(); err != nil {
closeWithLog(session, "session after shell error")
return fmt.Errorf("start shell: %w", err)
}
logrus.Info("SSH: Session started with PTY")
return nil
}
// Write sends data to the SSH session
func (c *Client) Write(data []byte) (int, error) {
c.mu.RLock()
stdin := c.stdin
c.mu.RUnlock()
if stdin == nil {
return 0, fmt.Errorf("SSH session not started")
}
return stdin.Write(data)
}
// Read reads data from the SSH session
func (c *Client) Read(buffer []byte) (int, error) {
c.mu.RLock()
stdout := c.stdout
c.mu.RUnlock()
if stdout == nil {
return 0, fmt.Errorf("SSH session not started")
}
return stdout.Read(buffer)
}
// Resize updates the terminal size
func (c *Client) Resize(cols, rows int) error {
c.mu.RLock()
session := c.session
c.mu.RUnlock()
if session == nil {
return fmt.Errorf("SSH session not started")
}
return session.WindowChange(rows, cols)
}
// Close closes the SSH connection
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.session != nil {
closeWithLog(c.session, "SSH session")
c.session = nil
}
if c.stdin != nil {
closeWithLog(c.stdin, "stdin")
c.stdin = nil
}
c.stdout = nil
c.stderr = nil
if c.sshClient != nil {
err := c.sshClient.Close()
c.sshClient = nil
return err
}
return nil
}

View File

@@ -0,0 +1,76 @@
package ssh
import (
"io"
"syscall/js"
"github.com/sirupsen/logrus"
)
// CreateJSInterface creates a JavaScript interface for the SSH client
func CreateJSInterface(client *Client) js.Value {
jsInterface := js.Global().Get("Object").Call("create", js.Null())
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf(false)
}
data := args[0]
var bytes []byte
if data.Type() == js.TypeString {
bytes = []byte(data.String())
} else {
uint8Array := js.Global().Get("Uint8Array").New(data)
length := uint8Array.Get("length").Int()
bytes = make([]byte, length)
js.CopyBytesToGo(bytes, uint8Array)
}
_, err := client.Write(bytes)
return js.ValueOf(err == nil)
}))
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf(false)
}
cols := args[0].Int()
rows := args[1].Int()
err := client.Resize(cols, rows)
return js.ValueOf(err == nil)
}))
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
client.Close()
return js.Undefined()
}))
go readLoop(client, jsInterface)
return jsInterface
}
func readLoop(client *Client, jsInterface js.Value) {
buffer := make([]byte, 4096)
for {
n, err := client.Read(buffer)
if err != nil {
if err != io.EOF {
logrus.Debugf("SSH read error: %v", err)
}
if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() {
onclose.Invoke()
}
client.Close()
return
}
if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() {
uint8Array := js.Global().Get("Uint8Array").New(n)
js.CopyBytesToJS(uint8Array, buffer[:n])
ondata.Invoke(uint8Array)
}
}
}

View File

@@ -0,0 +1,48 @@
package ssh
import (
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
keyStr := string(keyPEM)
if !strings.Contains(keyStr, "-----BEGIN") {
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
}
signer, err := ssh.ParsePrivateKey(keyPEM)
if err == nil {
return signer, nil
}
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
block, _ := pem.Decode(keyPEM)
if block == nil {
keyPreview := string(keyPEM)
if len(keyPreview) > 100 {
keyPreview = keyPreview[:100]
}
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(rsaKey)
}
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(ecKey)
}
return nil, fmt.Errorf("parse private key: %w", err)
}
return ssh.NewSignerFromKey(key)
}

View File

@@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
return nil, fmt.Errorf("parsing url: %w", err)
}
var opts []grpc.DialOption
if parsedURL.Scheme == "https" {
tlsEnabled := parsedURL.Scheme == "https"
if tlsEnabled {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
}
opts = append(opts,
nbgrpc.WithCustomDialer(),
nbgrpc.WithCustomDialer(tlsEnabled),
grpc.WithIdleTimeout(interval*2),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

2
go.mod
View File

@@ -38,7 +38,7 @@ require (
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.12
github.com/coder/websocket v1.8.13
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
github.com/eko/gocache/lib/v4 v4.2.0

4
go.sum
View File

@@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=

View File

@@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error {
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
connRemoteAddr := remoteAddr(r)
wsConn, err := websocket.Accept(w, r, nil)
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
return

View File

@@ -223,10 +223,10 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
_, ok := c.conns[peerID]
existingContainer, ok := c.conns[peerID]
if ok {
c.mu.Unlock()
return nil, ErrConnAlreadyExists
return existingContainer.conn, nil
}
c.mu.Unlock()
@@ -235,7 +235,6 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
c.mu.Lock()
@@ -249,11 +248,11 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID]
existingContainer, ok = c.conns[peerID]
if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists
return existingContainer.conn, nil
}
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
c.mu.Unlock()
@@ -377,7 +376,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internal
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to read message from relay server: %s", errExit)
@@ -468,12 +466,24 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
c.bufPool.Put(bufPtr)
return false
}
container, ok := c.conns[*peerID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", peerID.String())
c.bufPool.Put(bufPtr)
return true
// Try to create a connection for this peer to handle incoming messages
msgChannel := make(chan Msg, 100)
c.muInstanceURL.Lock()
instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, *peerID, msgChannel, instanceURL)
c.mu.Lock()
// Check again if connection was created while we were creating it
if _, exists := c.conns[*peerID]; !exists {
c.conns[*peerID] = newConnContainer(c.log, conn, msgChannel)
}
container = c.conns[*peerID]
c.mu.Unlock()
}
msg := Msg{
bufPool: c.bufPool,

View File

@@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return 0, err
return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b)
}
func (c *Conn) RemoteAddr() net.Addr {

View File

@@ -0,0 +1,11 @@
//go:build !js
package ws
import "github.com/coder/websocket"
func createDialOptions() *websocket.DialOptions {
return &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
}

View File

@@ -0,0 +1,10 @@
//go:build js
package ws
import "github.com/coder/websocket"
func createDialOptions() *websocket.DialOptions {
// WASM version doesn't support HTTPClient
return &websocket.DialOptions{}
}

View File

@@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return nil, err
}
opts := &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
opts := createDialOptions()
parsedURL, err := url.Parse(wsURL)
if err != nil {

View File

@@ -4,15 +4,8 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os/user"
"runtime"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -21,35 +14,9 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}
// grpcDialBackoff is the backoff mechanism for the grpc calls
// Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 10 * time.Second
@@ -57,6 +24,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
// CreateConnection creates a gRPC client connection with the appropriate transport options
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
@@ -78,7 +46,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
connCtx,
addr,
transportOption,
WithCustomDialer(),
WithCustomDialer(tlsEnabled),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second,

View File

@@ -0,0 +1,44 @@
//go:build !js
package grpc
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
}

12
util/grpc/dialer_js.go Normal file
View File

@@ -0,0 +1,12 @@
package grpc
import (
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy/client"
)
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled)
}

8
util/util_js.go Normal file
View File

@@ -0,0 +1,8 @@
//go:build js
package util
// IsAdmin returns false for WASM as there's no admin concept in browser
func IsAdmin() bool {
return false
}

View File

@@ -0,0 +1,171 @@
package client
import (
"context"
"fmt"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/util/wsproxy"
)
const dialTimeout = 30 * time.Second
// websocketConn wraps a JavaScript WebSocket to implement net.Conn
type websocketConn struct {
ws js.Value
remoteAddr string
messages chan []byte
readBuf []byte
ctx context.Context
cancel context.CancelFunc
mu sync.Mutex
}
func (c *websocketConn) Read(b []byte) (int, error) {
c.mu.Lock()
if len(c.readBuf) > 0 {
n := copy(b, c.readBuf)
c.readBuf = c.readBuf[n:]
c.mu.Unlock()
return n, nil
}
c.mu.Unlock()
select {
case data := <-c.messages:
n := copy(b, data)
if n < len(data) {
c.mu.Lock()
c.readBuf = data[n:]
c.mu.Unlock()
}
return n, nil
case <-c.ctx.Done():
return 0, c.ctx.Err()
}
}
func (c *websocketConn) Write(b []byte) (int, error) {
select {
case <-c.ctx.Done():
return 0, c.ctx.Err()
default:
}
uint8Array := js.Global().Get("Uint8Array").New(len(b))
js.CopyBytesToJS(uint8Array, b)
c.ws.Call("send", uint8Array)
return len(b), nil
}
func (c *websocketConn) Close() error {
c.cancel()
c.ws.Call("close")
return nil
}
func (c *websocketConn) LocalAddr() net.Addr {
return nil
}
func (c *websocketConn) RemoteAddr() net.Addr {
return stringAddr(c.remoteAddr)
}
func (c *websocketConn) SetDeadline(t time.Time) error {
return nil
}
func (c *websocketConn) SetReadDeadline(t time.Time) error {
return nil
}
func (c *websocketConn) SetWriteDeadline(t time.Time) error {
return nil
}
// stringAddr is a simple net.Addr that returns a string
type stringAddr string
func (s stringAddr) Network() string { return "tcp" }
func (s stringAddr) String() string { return string(s) }
// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments.
func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
scheme := "wss"
if !tlsEnabled {
scheme = "ws"
}
wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath)
ws := js.Global().Get("WebSocket").New(wsURL)
connCtx, connCancel := context.WithCancel(context.Background())
conn := &websocketConn{
ws: ws,
remoteAddr: addr,
messages: make(chan []byte, 100),
ctx: connCtx,
cancel: connCancel,
}
ws.Set("binaryType", "arraybuffer")
openCh := make(chan struct{})
errorCh := make(chan error, 1)
ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any {
close(openCh)
return nil
}))
ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any {
select {
case errorCh <- wsproxy.ErrConnectionFailed:
default:
}
return nil
}))
ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any {
event := args[0]
data := event.Get("data")
uint8Array := js.Global().Get("Uint8Array").New(data)
length := uint8Array.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, uint8Array)
select {
case conn.messages <- bytes:
default:
log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr)
}
return nil
}))
ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any {
conn.cancel()
return nil
}))
select {
case <-openCh:
return conn, nil
case err := <-errorCh:
return nil, err
case <-ctx.Done():
ws.Call("close")
return nil, ctx.Err()
case <-time.After(dialTimeout):
ws.Call("close")
return nil, wsproxy.ErrConnectionTimeout
}
})
}

13
util/wsproxy/constants.go Normal file
View File

@@ -0,0 +1,13 @@
package wsproxy
import "errors"
// ProxyPath is the standard path where the WebSocket proxy is mounted on servers.
const ProxyPath = "/ws-proxy"
// Common errors
var (
ErrConnectionTimeout = errors.New("WebSocket connection timeout")
ErrConnectionFailed = errors.New("WebSocket connection failed")
ErrBackendUnavailable = errors.New("backend unavailable")
)