mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add wasm client
This commit is contained in:
0
.gitmodules
vendored
Normal file
0
.gitmodules
vendored
Normal file
@@ -2,6 +2,18 @@ version: 2
|
|||||||
|
|
||||||
project_name: netbird
|
project_name: netbird
|
||||||
builds:
|
builds:
|
||||||
|
- id: netbird-wasm
|
||||||
|
dir: client/wasm/cmd
|
||||||
|
binary: netbird.wasm
|
||||||
|
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
|
||||||
|
goos:
|
||||||
|
- js
|
||||||
|
goarch:
|
||||||
|
- wasm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird
|
- id: netbird
|
||||||
dir: client
|
dir: client
|
||||||
binary: netbird
|
binary: netbird
|
||||||
@@ -115,6 +127,13 @@ archives:
|
|||||||
- builds:
|
- builds:
|
||||||
- netbird
|
- netbird
|
||||||
- netbird-static
|
- netbird-static
|
||||||
|
- id: netbird-wasm
|
||||||
|
builds:
|
||||||
|
- netbird-wasm
|
||||||
|
name_template: "{{ .ProjectName }}_wasm_{{ .Version }}"
|
||||||
|
format: tar.gz
|
||||||
|
files:
|
||||||
|
- none*
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
|
|||||||
8
client/cmd/debug_js.go
Normal file
8
client/cmd/debug_js.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// SetupDebugHandler is a no-op for WASM
|
||||||
|
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
|
||||||
|
// Debug handler not needed for WASM
|
||||||
|
}
|
||||||
@@ -23,23 +23,27 @@ import (
|
|||||||
|
|
||||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
var ErrClientNotStarted = errors.New("client not started")
|
var ErrClientNotStarted = errors.New("client not started")
|
||||||
|
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
setupKey string
|
setupKey string
|
||||||
|
jwtToken string
|
||||||
connect *internal.ConnectClient
|
connect *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// Options configures a new Client
|
// Options configures a new Client.
|
||||||
type Options struct {
|
type Options struct {
|
||||||
// DeviceName is this peer's name in the network
|
// DeviceName is this peer's name in the network
|
||||||
DeviceName string
|
DeviceName string
|
||||||
// SetupKey is used for authentication
|
// SetupKey is used for authentication
|
||||||
SetupKey string
|
SetupKey string
|
||||||
|
// JWTToken is used for JWT-based authentication
|
||||||
|
JWTToken string
|
||||||
// ManagementURL overrides the default management server URL
|
// ManagementURL overrides the default management server URL
|
||||||
ManagementURL string
|
ManagementURL string
|
||||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||||
@@ -58,8 +62,15 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new netbird embedded client
|
// New creates a new netbird embedded client.
|
||||||
func New(opts Options) (*Client, error) {
|
func New(opts Options) (*Client, error) {
|
||||||
|
if opts.SetupKey == "" && opts.JWTToken == "" {
|
||||||
|
return nil, fmt.Errorf("either SetupKey or JWTToken must be provided")
|
||||||
|
}
|
||||||
|
if opts.SetupKey != "" && opts.JWTToken != "" {
|
||||||
|
return nil, fmt.Errorf("cannot specify both SetupKey and JWTToken")
|
||||||
|
}
|
||||||
|
|
||||||
if opts.LogOutput != nil {
|
if opts.LogOutput != nil {
|
||||||
logrus.SetOutput(opts.LogOutput)
|
logrus.SetOutput(opts.LogOutput)
|
||||||
}
|
}
|
||||||
@@ -110,6 +121,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
return &Client{
|
return &Client{
|
||||||
deviceName: opts.DeviceName,
|
deviceName: opts.DeviceName,
|
||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -126,7 +138,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
ctx := internal.CtxInitState(context.Background())
|
ctx := internal.CtxInitState(context.Background())
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,6 +199,16 @@ func (c *Client) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetConfig returns a copy of the internal client config.
|
||||||
|
func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.config == nil {
|
||||||
|
return profilemanager.Config{}, ErrConfigNotInitialized
|
||||||
|
}
|
||||||
|
return *c.config, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Dial dials a network address in the netbird network.
|
// Dial dials a network address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
@@ -211,7 +233,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
|||||||
return nsnet.DialContext(ctx, network, address)
|
return nsnet.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenTCP listens on the given address in the netbird network
|
// ListenTCP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
nsnet, addr, err := c.getNet()
|
nsnet, addr, err := c.getNet()
|
||||||
@@ -232,7 +254,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
|||||||
return nsnet.ListenTCP(tcpAddr)
|
return nsnet.ListenTCP(tcpAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenUDP listens on the given address in the netbird network
|
// ListenUDP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||||
nsnet, addr, err := c.getNet()
|
nsnet, addr, err := c.getNet()
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
69
client/iface/bind/ice_bind_common.go
Normal file
69
client/iface/bind/ice_bind_common.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecvMessage represents a received message
|
||||||
|
type RecvMessage struct {
|
||||||
|
Endpoint *Endpoint
|
||||||
|
Buffer []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICEBind is a bind implementation that uses ICE candidates for connectivity
|
||||||
|
type ICEBind struct {
|
||||||
|
address wgaddr.Address
|
||||||
|
filterFn FilterFn
|
||||||
|
endpoints map[netip.Addr]net.Conn
|
||||||
|
endpointsMu sync.Mutex
|
||||||
|
udpMux *UniversalUDPMuxDefault
|
||||||
|
muUDPMux sync.Mutex
|
||||||
|
transportNet transport.Net
|
||||||
|
receiverCreated bool
|
||||||
|
activityRecorder *ActivityRecorder
|
||||||
|
RecvChan chan RecvMessage
|
||||||
|
closed bool // Flag to signal that bind is closed
|
||||||
|
closedMu sync.Mutex
|
||||||
|
mtu uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewICEBind creates a new ICEBind instance
|
||||||
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||||
|
return &ICEBind{
|
||||||
|
address: address,
|
||||||
|
transportNet: transportNet,
|
||||||
|
filterFn: filterFn,
|
||||||
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
RecvChan: make(chan RecvMessage, 100),
|
||||||
|
activityRecorder: NewActivityRecorder(),
|
||||||
|
mtu: mtu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFilter updates the filter function
|
||||||
|
func (s *ICEBind) SetFilter(filter FilterFn) {
|
||||||
|
s.filterFn = filter
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAddress returns the bind address
|
||||||
|
func (s *ICEBind) GetAddress() wgaddr.Address {
|
||||||
|
return s.address
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActivityRecorder returns the activity recorder
|
||||||
|
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||||
|
return s.activityRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MTU returns the maximum transmission unit
|
||||||
|
func (s *ICEBind) MTU() uint16 {
|
||||||
|
return s.mtu
|
||||||
|
}
|
||||||
141
client/iface/bind/ice_bind_js.go
Normal file
141
client/iface/bind/ice_bind_js.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.zx2c4.com/wireguard/conn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetICEMux returns a dummy UDP mux for WASM since browsers don't support UDP.
|
||||||
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open creates a receive function for handling relay packets in WASM.
|
||||||
|
func (s *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||||
|
log.Debugf("Open: creating receive function for port %d", uport)
|
||||||
|
|
||||||
|
s.closedMu.Lock()
|
||||||
|
s.closed = false
|
||||||
|
s.closedMu.Unlock()
|
||||||
|
|
||||||
|
if !s.receiverCreated {
|
||||||
|
s.receiverCreated = true
|
||||||
|
log.Debugf("Open: first call, setting receiverCreated=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
|
||||||
|
s.closedMu.Lock()
|
||||||
|
if s.closed {
|
||||||
|
s.closedMu.Unlock()
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
s.closedMu.Unlock()
|
||||||
|
|
||||||
|
msg, ok := <-s.RecvChan
|
||||||
|
if !ok {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(bufs[0], msg.Buffer)
|
||||||
|
sizes[0] = len(msg.Buffer)
|
||||||
|
eps[0] = conn.Endpoint(msg.Endpoint)
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Open: receive function created, returning port %d", uport)
|
||||||
|
return []conn.ReceiveFunc{receiveFn}, uport, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMark is not applicable in WASM/browser environment.
|
||||||
|
func (s *ICEBind) SetMark(_ uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send forwards packets through the relay connection for WASM.
|
||||||
|
func (s *ICEBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||||
|
if ep == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP := ep.DstIP()
|
||||||
|
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
relayConn, ok := s.endpoints[fakeIP]
|
||||||
|
s.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, buf := range bufs {
|
||||||
|
if _, err := relayConn.Write(buf); err != nil {
|
||||||
|
log.Errorf("Send: failed to write to relay: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEndpoint stores a relay endpoint for a fake IP.
|
||||||
|
func (s *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
defer s.endpointsMu.Unlock()
|
||||||
|
|
||||||
|
if oldConn, exists := s.endpoints[fakeIP]; exists {
|
||||||
|
if oldConn != conn {
|
||||||
|
log.Debugf("SetEndpoint: replacing existing connection for %s", fakeIP)
|
||||||
|
if err := oldConn.Close(); err != nil {
|
||||||
|
log.Debugf("SetEndpoint: error closing old connection: %v", err)
|
||||||
|
}
|
||||||
|
s.endpoints[fakeIP] = conn
|
||||||
|
} else {
|
||||||
|
log.Tracef("SetEndpoint: same connection already set for %s, skipping", fakeIP)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debugf("SetEndpoint: setting new relay connection for fake IP %s", fakeIP)
|
||||||
|
s.endpoints[fakeIP] = conn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveEndpoint removes a relay endpoint.
|
||||||
|
func (s *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||||
|
s.endpointsMu.Lock()
|
||||||
|
defer s.endpointsMu.Unlock()
|
||||||
|
delete(s.endpoints, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchSize returns the batch size for WASM.
|
||||||
|
func (s *ICEBind) BatchSize() int {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseEndpoint parses an endpoint string.
|
||||||
|
func (s *ICEBind) ParseEndpoint(s2 string) (conn.Endpoint, error) {
|
||||||
|
addrPort, err := netip.ParseAddrPort(s2)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("ParseEndpoint: failed to parse %s: %v", s2, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ep := &Endpoint{AddrPort: addrPort}
|
||||||
|
return ep, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the ICEBind.
|
||||||
|
func (s *ICEBind) Close() error {
|
||||||
|
log.Debugf("Close: closing ICEBind (receiverCreated=%v)", s.receiverCreated)
|
||||||
|
|
||||||
|
s.closedMu.Lock()
|
||||||
|
s.closed = true
|
||||||
|
s.closedMu.Unlock()
|
||||||
|
|
||||||
|
s.receiverCreated = false
|
||||||
|
|
||||||
|
log.Debugf("Close: returning from Close")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build linux || windows || freebsd
|
//go:build linux || windows || freebsd || js || wasip1
|
||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !windows
|
//go:build !windows && !js
|
||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
|
|||||||
23
client/iface/configurer/uapi_js.go
Normal file
23
client/iface/configurer/uapi_js.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type noopListener struct{}
|
||||||
|
|
||||||
|
func (n *noopListener) Accept() (net.Conn, error) {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopListener) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *noopListener) Addr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func openUAPI(deviceName string) (net.Listener, error) {
|
||||||
|
return &noopListener{}, nil
|
||||||
|
}
|
||||||
6
client/iface/iface_destroy_js.go
Normal file
6
client/iface/iface_destroy_js.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
// Destroy is a no-op on WASM
|
||||||
|
func (w *WGIface) Destroy() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
27
client/iface/iface_new_js.go
Normal file
27
client/iface/iface_new_js.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||||
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||||
|
|
||||||
|
wgIface := &WGIface{
|
||||||
|
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||||
|
userspaceBind: true,
|
||||||
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
|
}
|
||||||
|
|
||||||
|
return wgIface, nil
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
12
client/iface/netstack/env_js.go
Normal file
12
client/iface/netstack/env_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package netstack
|
||||||
|
|
||||||
|
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||||
|
|
||||||
|
// IsEnabled always returns true for js since it's the only mode available
|
||||||
|
func IsEnabled() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenAddr() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
5
client/internal/dns/server_js.go
Normal file
5
client/internal/dns/server_js.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||||
|
return &noopHostConfigurator{}, nil
|
||||||
|
}
|
||||||
19
client/internal/dns/unclean_shutdown_js.go
Normal file
19
client/internal/dns/unclean_shutdown_js.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ShutdownState struct{}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "dns_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -450,14 +450,7 @@ func (e *Engine) Start() error {
|
|||||||
return fmt.Errorf("initialize dns server: %w", err)
|
return fmt.Errorf("initialize dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
iceCfg := icemaker.Config{
|
iceCfg := e.createICEConfig()
|
||||||
StunTurn: &e.stunTurn,
|
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
|
||||||
UDPMux: e.udpMux.UDPMuxDefault,
|
|
||||||
UDPMuxSrflx: e.udpMux,
|
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
|
||||||
}
|
|
||||||
|
|
||||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||||
e.connMgr.Start(e.ctx)
|
e.connMgr.Start(e.ctx)
|
||||||
@@ -1288,14 +1281,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
Addr: e.getRosenpassAddr(),
|
Addr: e.getRosenpassAddr(),
|
||||||
PermissiveMode: e.config.RosenpassPermissive,
|
PermissiveMode: e.config.RosenpassPermissive,
|
||||||
},
|
},
|
||||||
ICEConfig: icemaker.Config{
|
ICEConfig: e.createICEConfig(),
|
||||||
StunTurn: &e.stunTurn,
|
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
|
||||||
UDPMux: e.udpMux.UDPMuxDefault,
|
|
||||||
UDPMuxSrflx: e.udpMux,
|
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceDependencies := peer.ServiceDependencies{
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
|
|||||||
19
client/internal/engine_generic.go
Normal file
19
client/internal/engine_generic.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createICEConfig creates ICE configuration for non-WASM environments
|
||||||
|
func (e *Engine) createICEConfig() icemaker.Config {
|
||||||
|
return icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
UDPMux: e.udpMux.UDPMuxDefault,
|
||||||
|
UDPMuxSrflx: e.udpMux,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
}
|
||||||
24
client/internal/engine_js.go
Normal file
24
client/internal/engine_js.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createICEConfig creates ICE configuration for WASM environment.
|
||||||
|
func (e *Engine) createICEConfig() icemaker.Config {
|
||||||
|
cfg := icemaker.Config{
|
||||||
|
StunTurn: &e.stunTurn,
|
||||||
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.udpMux != nil {
|
||||||
|
cfg.UDPMux = e.udpMux.UDPMuxDefault
|
||||||
|
cfg.UDPMuxSrflx = e.udpMux
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
12
client/internal/networkmonitor/check_change_js.go
Normal file
12
client/internal/networkmonitor/check_change_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
// No-op for WASM - network changes don't apply
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -455,6 +455,14 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we already have a relay proxy configured
|
||||||
|
if conn.wgProxyRelay != nil {
|
||||||
|
conn.Log.Debugf("Relay proxy already configured, skipping duplicate setup")
|
||||||
|
// Update status to ensure it's connected
|
||||||
|
conn.statusRelay.SetConnected()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
conn.dumpState.RelayConnected()
|
conn.dumpState.RelayConnected()
|
||||||
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
|
conn.Log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||||
|
|
||||||
@@ -472,19 +480,42 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
if conn.isICEActive() {
|
if conn.isICEActive() {
|
||||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||||
conn.setRelayedProxy(wgProxy)
|
conn.setRelayedProxy(wgProxy)
|
||||||
|
// For WASM, we still need to start the proxy and configure WireGuard
|
||||||
|
// because ICE doesn't actually work in browsers and we rely on relay
|
||||||
|
conn.Log.Infof("WASM check: runtime.GOOS=%s, should start proxy=%v", runtime.GOOS, runtime.GOOS == "js")
|
||||||
|
if runtime.GOOS == "js" {
|
||||||
|
conn.Log.Infof("WASM: starting relay proxy and configuring WireGuard despite ICE being 'active'")
|
||||||
|
wgProxy.Work()
|
||||||
|
|
||||||
|
// Configure WireGuard to use the relay proxy endpoint
|
||||||
|
endpointAddr := wgProxy.EndpointAddr()
|
||||||
|
conn.Log.Infof("WASM: Configuring WireGuard endpoint to proxy address: %v", endpointAddr)
|
||||||
|
if err := conn.configureWGEndpoint(endpointAddr, rci.rosenpassPubKey); err != nil {
|
||||||
|
conn.Log.Errorf("WASM: Failed to update WireGuard peer configuration: %v", err)
|
||||||
|
} else {
|
||||||
|
conn.Log.Infof("WASM: Successfully configured WireGuard endpoint to use proxy at %v", endpointAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update connection priority to relay for WASM
|
||||||
|
conn.currentConnPriority = conntype.Relay
|
||||||
|
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||||
|
}
|
||||||
conn.statusRelay.SetConnected()
|
conn.statusRelay.SetConnected()
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
endpointAddr := wgProxy.EndpointAddr()
|
||||||
|
conn.Log.Infof("Configuring WireGuard endpoint to proxy address: %v", endpointAddr)
|
||||||
|
if err := conn.configureWGEndpoint(endpointAddr, rci.rosenpassPubKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn.Log.Infof("Successfully configured WireGuard endpoint to use proxy at %v", endpointAddr)
|
||||||
|
|
||||||
conn.wgWatcherWg.Add(1)
|
conn.wgWatcherWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -663,6 +694,21 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// In WASM with forced relay, ICE is not used, so skip ICE check
|
||||||
|
if runtime.GOOS == "js" && os.Getenv("NB_FORCE_RELAY") == "true" {
|
||||||
|
// Only check relay connection status
|
||||||
|
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||||
|
relayConnected := conn.statusRelay.Get() != worker.StatusDisconnected
|
||||||
|
if !relayConnected {
|
||||||
|
conn.Log.Tracef("WASM: relay not connected for connectivity check")
|
||||||
|
}
|
||||||
|
return relayConnected
|
||||||
|
}
|
||||||
|
// If relay is not supported, consider it connected to avoid reconnect loop
|
||||||
|
conn.Log.Tracef("WASM: relay not supported, returning true to avoid reconnect")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package peer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -55,6 +54,22 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
}
|
}
|
||||||
w.relaySupportedOnRemotePeer.Store(true)
|
w.relaySupportedOnRemotePeer.Store(true)
|
||||||
|
|
||||||
|
// Check if we already have an active relay connection
|
||||||
|
w.relayLock.Lock()
|
||||||
|
existingConn := w.relayedConn
|
||||||
|
w.relayLock.Unlock()
|
||||||
|
|
||||||
|
if existingConn != nil {
|
||||||
|
w.log.Debugf("relay connection already exists for peer %s, reusing it", w.config.Key)
|
||||||
|
// Connection exists, just ensure proxy is set up if needed
|
||||||
|
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||||
|
relayedConn: existingConn,
|
||||||
|
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||||
|
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// the relayManager will return with error in case if the connection has lost with relay server
|
// the relayManager will return with error in case if the connection has lost with relay server
|
||||||
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
|
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -66,15 +81,24 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
|
|
||||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
// The relay manager never actually returns ErrConnAlreadyExists - it returns
|
||||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
// the existing connection with nil error. This error handling is for other failures.
|
||||||
return
|
|
||||||
}
|
|
||||||
w.log.Errorf("failed to open connection via Relay: %s", err)
|
w.log.Errorf("failed to open connection via Relay: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.relayLock.Lock()
|
w.relayLock.Lock()
|
||||||
|
// Check if we already stored this connection (might happen if OpenConn returned existing)
|
||||||
|
if w.relayedConn != nil && w.relayedConn == relayedConn {
|
||||||
|
w.relayLock.Unlock()
|
||||||
|
w.log.Debugf("OpenConn returned the same connection we already have for peer %s", w.config.Key)
|
||||||
|
go w.conn.onRelayConnectionIsReady(RelayConnInfo{
|
||||||
|
relayedConn: relayedConn,
|
||||||
|
rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
|
||||||
|
rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
w.relayedConn = relayedConn
|
w.relayedConn = relayedConn
|
||||||
w.relayLock.Unlock()
|
w.relayLock.Unlock()
|
||||||
|
|
||||||
@@ -123,11 +147,16 @@ func (w *WorkerRelay) CloseConn() {
|
|||||||
if err := w.relayedConn.Close(); err != nil {
|
if err := w.relayedConn.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close relay connection: %v", err)
|
w.log.Warnf("failed to close relay connection: %v", err)
|
||||||
}
|
}
|
||||||
|
// Clear the stored connection to allow reopening
|
||||||
|
w.relayedConn = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) onWGDisconnected() {
|
func (w *WorkerRelay) onWGDisconnected() {
|
||||||
w.relayLock.Lock()
|
w.relayLock.Lock()
|
||||||
_ = w.relayedConn.Close()
|
if w.relayedConn != nil {
|
||||||
|
_ = w.relayedConn.Close()
|
||||||
|
w.relayedConn = nil
|
||||||
|
}
|
||||||
w.relayLock.Unlock()
|
w.relayLock.Unlock()
|
||||||
|
|
||||||
w.conn.onRelayDisconnected()
|
w.conn.onRelayDisconnected()
|
||||||
@@ -148,6 +177,11 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) onRelayClientDisconnected() {
|
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||||
|
// Clear the stored connection when relay disconnects
|
||||||
|
w.relayLock.Lock()
|
||||||
|
w.relayedConn = nil
|
||||||
|
w.relayLock.Unlock()
|
||||||
|
|
||||||
w.wgWatcher.DisableWgWatcher()
|
w.wgWatcher.DisableWgWatcher()
|
||||||
go w.conn.onRelayDisconnected()
|
go w.conn.onRelayDisconnected()
|
||||||
}
|
}
|
||||||
|
|||||||
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrRouteNotSupported = errors.New("route operations not supported on js")
|
||||||
|
|
||||||
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
|
return []netip.Prefix{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||||
|
return []netip.Prefix{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||||
|
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||||
|
return []DetailedRoute{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||||
|
return ErrRouteNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !ios
|
//go:build !linux && !ios && !js
|
||||||
|
|
||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import "context"
|
import "context"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
137
client/ssh/ssh_js.go
Normal file
137
client/ssh/ssh_js.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment")
|
||||||
|
|
||||||
|
// Server is a dummy SSH server interface for WASM.
|
||||||
|
type Server interface {
|
||||||
|
Start() error
|
||||||
|
Stop() error
|
||||||
|
EnableSSH(enabled bool)
|
||||||
|
AddAuthorizedKey(peer string, key string) error
|
||||||
|
RemoveAuthorizedKey(key string)
|
||||||
|
}
|
||||||
|
|
||||||
|
type dummyServer struct{}
|
||||||
|
|
||||||
|
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||||
|
return &dummyServer{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(addr string) Server {
|
||||||
|
return &dummyServer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dummyServer) Start() error {
|
||||||
|
return ErrSSHNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dummyServer) Stop() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dummyServer) EnableSSH(enabled bool) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dummyServer) AddAuthorizedKey(peer string, key string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *dummyServer) RemoveAuthorizedKey(key string) {
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct{}
|
||||||
|
|
||||||
|
func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) {
|
||||||
|
return nil, ErrSSHNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Run(command []string) error {
|
||||||
|
return ErrSSHNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionRecorder struct{}
|
||||||
|
|
||||||
|
func NewSessionRecorder() *SessionRecorder {
|
||||||
|
return &SessionRecorder{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SessionRecorder) Record(session string, data []byte) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUserShell() string {
|
||||||
|
return "/bin/sh"
|
||||||
|
}
|
||||||
|
|
||||||
|
func LookupUserInfo(username string) (string, string, error) {
|
||||||
|
return "", "", ErrSSHNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
const DefaultSSHPort = 44338
|
||||||
|
|
||||||
|
const ED25519 = "ed25519"
|
||||||
|
|
||||||
|
func isRoot() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePrivateKey(keyType string) ([]byte, error) {
|
||||||
|
if keyType != ED25519 {
|
||||||
|
return nil, errors.New("only ED25519 keys are supported in WASM")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBlock := &pem.Block{
|
||||||
|
Type: "PRIVATE KEY",
|
||||||
|
Bytes: pkcs8Bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes := pem.EncodeToMemory(pemBlock)
|
||||||
|
return pemBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePublicKey(privateKey []byte) ([]byte, error) {
|
||||||
|
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
block, _ := pem.Decode(privateKey)
|
||||||
|
if block != nil {
|
||||||
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signer, err = ssh.NewSignerFromKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
||||||
|
return []byte(strings.TrimSpace(string(pubKeyBytes))), nil
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
234
client/system/info_js.go
Normal file
234
client/system/info_js.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"syscall/js"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetInfo retrieves system information for WASM environment
|
||||||
|
func GetInfo(_ context.Context) *Info {
|
||||||
|
info := &Info{
|
||||||
|
GoOS: runtime.GOOS,
|
||||||
|
Kernel: runtime.GOARCH,
|
||||||
|
KernelVersion: runtime.GOARCH,
|
||||||
|
Platform: runtime.GOARCH,
|
||||||
|
OS: runtime.GOARCH,
|
||||||
|
Hostname: "wasm-client",
|
||||||
|
CPUs: runtime.NumCPU(),
|
||||||
|
NetbirdVersion: version.NetbirdVersion(),
|
||||||
|
}
|
||||||
|
|
||||||
|
collectBrowserInfo(info)
|
||||||
|
collectLocationInfo(info)
|
||||||
|
|
||||||
|
si := updateStaticInfo()
|
||||||
|
info.SystemSerialNumber = si.SystemSerialNumber
|
||||||
|
info.SystemProductName = si.SystemProductName
|
||||||
|
info.SystemManufacturer = si.SystemManufacturer
|
||||||
|
info.Environment = si.Environment
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectBrowserInfo(info *Info) {
|
||||||
|
navigator := js.Global().Get("navigator")
|
||||||
|
if navigator.IsUndefined() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
collectUserAgent(info, navigator)
|
||||||
|
collectPlatform(info, navigator)
|
||||||
|
collectCPUInfo(info, navigator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectUserAgent(info *Info, navigator js.Value) {
|
||||||
|
ua := navigator.Get("userAgent")
|
||||||
|
if ua.IsUndefined() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgent := ua.String()
|
||||||
|
os, osVersion := parseOSFromUserAgent(userAgent)
|
||||||
|
if os != "" {
|
||||||
|
info.OS = os
|
||||||
|
}
|
||||||
|
if osVersion != "" {
|
||||||
|
info.OSVersion = osVersion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectPlatform(info *Info, navigator js.Value) {
|
||||||
|
// Try regular platform property
|
||||||
|
if plat := navigator.Get("platform"); !plat.IsUndefined() {
|
||||||
|
if platStr := plat.String(); platStr != "" {
|
||||||
|
info.Platform = platStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try newer userAgentData API for more accurate platform
|
||||||
|
userAgentData := navigator.Get("userAgentData")
|
||||||
|
if userAgentData.IsUndefined() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
platformInfo := userAgentData.Get("platform")
|
||||||
|
if !platformInfo.IsUndefined() {
|
||||||
|
if platStr := platformInfo.String(); platStr != "" {
|
||||||
|
info.Platform = platStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectCPUInfo(info *Info, navigator js.Value) {
|
||||||
|
hardwareConcurrency := navigator.Get("hardwareConcurrency")
|
||||||
|
if !hardwareConcurrency.IsUndefined() {
|
||||||
|
info.CPUs = hardwareConcurrency.Int()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectLocationInfo(info *Info) {
|
||||||
|
location := js.Global().Get("location")
|
||||||
|
if location.IsUndefined() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if host := location.Get("hostname"); !host.IsUndefined() {
|
||||||
|
hostnameStr := host.String()
|
||||||
|
if hostnameStr != "" && hostnameStr != "localhost" {
|
||||||
|
info.Hostname = hostnameStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||||
|
return []File{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateStaticInfo() *StaticInfo {
|
||||||
|
si := &StaticInfo{}
|
||||||
|
|
||||||
|
navigator := js.Global().Get("navigator")
|
||||||
|
if !navigator.IsUndefined() {
|
||||||
|
if vendor := navigator.Get("vendor"); !vendor.IsUndefined() {
|
||||||
|
si.SystemManufacturer = vendor.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if product := navigator.Get("product"); !product.IsUndefined() {
|
||||||
|
si.SystemProductName = product.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() {
|
||||||
|
ua := userAgent.String()
|
||||||
|
si.Environment = detectEnvironmentFromUA(ua)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return si
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOSFromUserAgent(userAgent string) (string, string) {
|
||||||
|
if userAgent == "" {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(userAgent, "Windows NT"):
|
||||||
|
return parseWindowsVersion(userAgent)
|
||||||
|
case strings.Contains(userAgent, "Mac OS X"):
|
||||||
|
return parseMacOSVersion(userAgent)
|
||||||
|
case strings.Contains(userAgent, "FreeBSD"):
|
||||||
|
return "FreeBSD", ""
|
||||||
|
case strings.Contains(userAgent, "OpenBSD"):
|
||||||
|
return "OpenBSD", ""
|
||||||
|
case strings.Contains(userAgent, "NetBSD"):
|
||||||
|
return "NetBSD", ""
|
||||||
|
case strings.Contains(userAgent, "Linux"):
|
||||||
|
return parseLinuxVersion(userAgent)
|
||||||
|
case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"):
|
||||||
|
return parseiOSVersion(userAgent)
|
||||||
|
case strings.Contains(userAgent, "CrOS"):
|
||||||
|
return "ChromeOS", ""
|
||||||
|
default:
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseWindowsVersion(userAgent string) (string, string) {
|
||||||
|
switch {
|
||||||
|
case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"):
|
||||||
|
return "Windows", "10/11"
|
||||||
|
case strings.Contains(userAgent, "Windows NT 10.0"):
|
||||||
|
return "Windows", "10"
|
||||||
|
case strings.Contains(userAgent, "Windows NT 6.3"):
|
||||||
|
return "Windows", "8.1"
|
||||||
|
case strings.Contains(userAgent, "Windows NT 6.2"):
|
||||||
|
return "Windows", "8"
|
||||||
|
case strings.Contains(userAgent, "Windows NT 6.1"):
|
||||||
|
return "Windows", "7"
|
||||||
|
default:
|
||||||
|
return "Windows", "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMacOSVersion(userAgent string) (string, string) {
|
||||||
|
idx := strings.Index(userAgent, "Mac OS X ")
|
||||||
|
if idx == -1 {
|
||||||
|
return "macOS", "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
versionStart := idx + len("Mac OS X ")
|
||||||
|
versionEnd := strings.Index(userAgent[versionStart:], ")")
|
||||||
|
if versionEnd <= 0 {
|
||||||
|
return "macOS", "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
ver := userAgent[versionStart : versionStart+versionEnd]
|
||||||
|
ver = strings.ReplaceAll(ver, "_", ".")
|
||||||
|
return "macOS", ver
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLinuxVersion(userAgent string) (string, string) {
|
||||||
|
if strings.Contains(userAgent, "Android") {
|
||||||
|
return "Android", extractAndroidVersion(userAgent)
|
||||||
|
}
|
||||||
|
if strings.Contains(userAgent, "Ubuntu") {
|
||||||
|
return "Ubuntu", ""
|
||||||
|
}
|
||||||
|
return "Linux", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseiOSVersion(userAgent string) (string, string) {
|
||||||
|
idx := strings.Index(userAgent, "OS ")
|
||||||
|
if idx == -1 {
|
||||||
|
return "iOS", "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
versionStart := idx + 3
|
||||||
|
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||||
|
if versionEnd <= 0 {
|
||||||
|
return "iOS", "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
ver := userAgent[versionStart : versionStart+versionEnd]
|
||||||
|
ver = strings.ReplaceAll(ver, "_", ".")
|
||||||
|
return "iOS", ver
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractAndroidVersion(userAgent string) string {
|
||||||
|
if idx := strings.Index(userAgent, "Android "); idx != -1 {
|
||||||
|
versionStart := idx + len("Android ")
|
||||||
|
versionEnd := strings.IndexAny(userAgent[versionStart:], ";)")
|
||||||
|
if versionEnd > 0 {
|
||||||
|
return userAgent[versionStart : versionStart+versionEnd]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectEnvironmentFromUA(_ string) Environment {
|
||||||
|
return Environment{}
|
||||||
|
}
|
||||||
236
client/wasm/cmd/main.go
Normal file
236
client/wasm/cmd/main.go
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"syscall/js"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||||
|
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||||
|
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
clientStartTimeout = 30 * time.Second
|
||||||
|
clientStopTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
|
||||||
|
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startClient(ctx context.Context, nbClient *netbird.Client) error {
|
||||||
|
log.Println("Starting NetBird client...")
|
||||||
|
if err := nbClient.Start(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println("NetBird client started successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseClientOptions extracts NetBird options from JavaScript object
|
||||||
|
func parseClientOptions(jsOptions js.Value) (netbird.Options, error) {
|
||||||
|
options := netbird.Options{
|
||||||
|
DeviceName: "dashboard-client",
|
||||||
|
LogLevel: "warn",
|
||||||
|
}
|
||||||
|
|
||||||
|
if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() {
|
||||||
|
options.JWTToken = jwtToken.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() {
|
||||||
|
options.SetupKey = setupKey.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.JWTToken == "" && options.SetupKey == "" {
|
||||||
|
return options, fmt.Errorf("either jwtToken or setupKey must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() {
|
||||||
|
mgmtURLStr := mgmtURL.String()
|
||||||
|
if mgmtURLStr != "" {
|
||||||
|
options.ManagementURL = mgmtURLStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() {
|
||||||
|
options.LogLevel = logLevel.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() {
|
||||||
|
options.DeviceName = deviceName.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return options, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createStartMethod creates the start method for the client
|
||||||
|
func createStartMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := startClient(ctx, client); err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resolve.Invoke(js.ValueOf(true))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createStopMethod creates the stop method for the client
|
||||||
|
func createStopMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.Stop(ctx); err != nil {
|
||||||
|
log.Printf("Error stopping client: %v", err)
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("NetBird client stopped")
|
||||||
|
resolve.Invoke(js.ValueOf(true))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSSHMethod creates the SSH connection method
|
||||||
|
func createSSHMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return js.ValueOf("error: requires host and port")
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0].String()
|
||||||
|
port := args[1].Int()
|
||||||
|
username := "root"
|
||||||
|
if len(args) > 2 && args[2].String() != "" {
|
||||||
|
username = args[2].String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
sshClient := ssh.NewClient(client)
|
||||||
|
|
||||||
|
if err := sshClient.Connect(host, port, username); err != nil {
|
||||||
|
reject.Invoke(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sshClient.StartSession(80, 24); err != nil {
|
||||||
|
if closeErr := sshClient.Close(); closeErr != nil {
|
||||||
|
log.Printf("Error closing SSH client: %v", closeErr)
|
||||||
|
}
|
||||||
|
reject.Invoke(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
jsInterface := ssh.CreateJSInterface(sshClient)
|
||||||
|
resolve.Invoke(jsInterface)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createProxyRequestMethod creates the proxyRequest method
|
||||||
|
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return js.ValueOf("error: request details required")
|
||||||
|
}
|
||||||
|
|
||||||
|
request := args[0]
|
||||||
|
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
response, err := http.ProxyRequest(client, request)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resolve.Invoke(response)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createRDPProxyMethod creates the RDP proxy method
|
||||||
|
func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return js.ValueOf("error: hostname and port required")
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := rdp.NewRDCleanPathProxy(client)
|
||||||
|
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// createPromise is a helper to create JavaScript promises
|
||||||
|
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||||
|
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||||
|
resolve := promiseArgs[0]
|
||||||
|
reject := promiseArgs[1]
|
||||||
|
|
||||||
|
go handler(resolve, reject)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// createClientObject wraps the NetBird client in a JavaScript object
|
||||||
|
func createClientObject(client *netbird.Client) js.Value {
|
||||||
|
obj := make(map[string]interface{})
|
||||||
|
|
||||||
|
obj["start"] = createStartMethod(client)
|
||||||
|
obj["stop"] = createStopMethod(client)
|
||||||
|
obj["createSSHConnection"] = createSSHMethod(client)
|
||||||
|
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||||
|
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||||
|
|
||||||
|
return js.ValueOf(obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||||
|
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
||||||
|
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
||||||
|
resolve := promiseArgs[0]
|
||||||
|
reject := promiseArgs[1]
|
||||||
|
|
||||||
|
if len(args) < 1 {
|
||||||
|
reject.Invoke(js.ValueOf("Options object required"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
options, err := parseClientOptions(args[0])
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s",
|
||||||
|
options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL)
|
||||||
|
|
||||||
|
client, err := netbird.New(options)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientObj := createClientObject(client)
|
||||||
|
log.Println("NetBird client created successfully")
|
||||||
|
resolve.Invoke(clientObj)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
||||||
98
client/wasm/internal/http/http.go
Normal file
98
client/wasm/internal/http/http.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package http
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"syscall/js"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
httpTimeout = 30 * time.Second
|
||||||
|
maxResponseSize = 1024 * 1024 // 1MB
|
||||||
|
)
|
||||||
|
|
||||||
|
// performRequest executes an HTTP request through NetBird and returns the response and body
|
||||||
|
func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) {
|
||||||
|
httpClient := nbClient.NewHTTPClient()
|
||||||
|
httpClient.Timeout = httpTimeout
|
||||||
|
|
||||||
|
req, err := http.NewRequest(method, url, strings.NewReader(string(body)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value := range headers {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("failed to close response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, respBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object
|
||||||
|
func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) {
|
||||||
|
url := request.Get("url").String()
|
||||||
|
if url == "" {
|
||||||
|
return js.Undefined(), fmt.Errorf("URL is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
method := "GET"
|
||||||
|
if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() {
|
||||||
|
method = strings.ToUpper(methodVal.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestBody []byte
|
||||||
|
if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() {
|
||||||
|
requestBody = []byte(bodyVal.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
requestHeaders := make(map[string]string)
|
||||||
|
if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject {
|
||||||
|
headerKeys := js.Global().Get("Object").Call("keys", headersVal)
|
||||||
|
for i := 0; i < headerKeys.Length(); i++ {
|
||||||
|
key := headerKeys.Index(i).String()
|
||||||
|
value := headersVal.Get(key).String()
|
||||||
|
requestHeaders[key] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody)
|
||||||
|
if err != nil {
|
||||||
|
return js.Undefined(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
result := js.Global().Get("Object").New()
|
||||||
|
result.Set("status", resp.StatusCode)
|
||||||
|
result.Set("statusText", resp.Status)
|
||||||
|
result.Set("body", string(body))
|
||||||
|
|
||||||
|
headers := js.Global().Get("Object").New()
|
||||||
|
for key, values := range resp.Header {
|
||||||
|
if len(values) > 0 {
|
||||||
|
headers.Set(strings.ToLower(key), values[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.Set("headers", headers)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
94
client/wasm/internal/rdp/cert_validation.go
Normal file
94
client/wasm/internal/rdp/cert_validation.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package rdp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"syscall/js"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
certValidationTimeout = 60 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||||
|
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
|
||||||
|
return false, fmt.Errorf("certificate validation handler not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
certInfo := js.Global().Get("Object").New()
|
||||||
|
certInfo.Set("ServerAddr", conn.destination)
|
||||||
|
|
||||||
|
certArray := js.Global().Get("Array").New()
|
||||||
|
for i, certBytes := range certChain {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
|
||||||
|
js.CopyBytesToJS(uint8Array, certBytes)
|
||||||
|
certArray.SetIndex(i, uint8Array)
|
||||||
|
}
|
||||||
|
certInfo.Set("ServerCertChain", certArray)
|
||||||
|
if len(certChain) > 0 {
|
||||||
|
cert, err := x509.ParseCertificate(certChain[0])
|
||||||
|
if err == nil {
|
||||||
|
info := js.Global().Get("Object").New()
|
||||||
|
info.Set("subject", cert.Subject.String())
|
||||||
|
info.Set("issuer", cert.Issuer.String())
|
||||||
|
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
|
||||||
|
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
|
||||||
|
info.Set("serialNumber", cert.SerialNumber.String())
|
||||||
|
certInfo.Set("CertificateInfo", info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||||
|
|
||||||
|
resultChan := make(chan bool)
|
||||||
|
errorChan := make(chan error)
|
||||||
|
|
||||||
|
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
result := args[0].Bool()
|
||||||
|
resultChan <- result
|
||||||
|
return nil
|
||||||
|
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
errorChan <- fmt.Errorf("certificate validation failed")
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultChan:
|
||||||
|
if result {
|
||||||
|
log.Info("Certificate accepted by user")
|
||||||
|
} else {
|
||||||
|
log.Info("Certificate rejected by user")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
case err := <-errorChan:
|
||||||
|
return false, err
|
||||||
|
case <-time.After(certValidationTimeout):
|
||||||
|
return false, fmt.Errorf("certificate validation timeout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||||
|
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||||
|
var certChain [][]byte
|
||||||
|
for _, cert := range cs.PeerCertificates {
|
||||||
|
certChain = append(certChain, cert.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
accepted, err := p.validateCertificateWithJS(conn, certChain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !accepted {
|
||||||
|
return fmt.Errorf("certificate rejected by user")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
269
client/wasm/internal/rdp/rdcleanpath.go
Normal file
269
client/wasm/internal/rdp/rdcleanpath.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
package rdp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/asn1"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"syscall/js"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RDCleanPathVersion = 3390
|
||||||
|
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||||
|
RDCleanPathProxyScheme = "ws"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RDCleanPathPDU struct {
|
||||||
|
Version int64 `asn1:"tag:0,explicit"`
|
||||||
|
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||||
|
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||||
|
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||||
|
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||||
|
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||||
|
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||||
|
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||||
|
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RDCleanPathProxy struct {
|
||||||
|
nbClient interface {
|
||||||
|
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
activeConnections map[string]*proxyConnection
|
||||||
|
destinations map[string]string
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyConnection struct {
|
||||||
|
id string
|
||||||
|
destination string
|
||||||
|
rdpConn net.Conn
|
||||||
|
tlsConn *tls.Conn
|
||||||
|
wsHandlers js.Value
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||||
|
func NewRDCleanPathProxy(client interface {
|
||||||
|
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}) *RDCleanPathProxy {
|
||||||
|
return &RDCleanPathProxy{
|
||||||
|
nbClient: client,
|
||||||
|
activeConnections: make(map[string]*proxyConnection),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateProxy creates a new proxy endpoint for the given destination
|
||||||
|
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||||
|
destination := fmt.Sprintf("%s:%s", hostname, port)
|
||||||
|
|
||||||
|
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
resolve := args[0]
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
if p.destinations == nil {
|
||||||
|
p.destinations = make(map[string]string)
|
||||||
|
}
|
||||||
|
p.destinations[proxyID] = destination
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||||
|
|
||||||
|
// Register the WebSocket handler for this specific proxy
|
||||||
|
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return js.ValueOf("error: requires WebSocket argument")
|
||||||
|
}
|
||||||
|
|
||||||
|
ws := args[0]
|
||||||
|
p.HandleWebSocketConnection(ws, proxyID)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||||
|
resolve.Invoke(proxyURL)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
|
||||||
|
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
|
||||||
|
p.mu.Lock()
|
||||||
|
destination := p.destinations[proxyID]
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if destination == "" {
|
||||||
|
log.Errorf("No destination found for proxy ID: %s", proxyID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
// Don't defer cancel here - it will be called by cleanupConnection
|
||||||
|
|
||||||
|
conn := &proxyConnection{
|
||||||
|
id: proxyID,
|
||||||
|
destination: destination,
|
||||||
|
wsHandlers: ws,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.activeConnections[proxyID] = conn
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
p.setupWebSocketHandlers(ws, conn)
|
||||||
|
|
||||||
|
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||||
|
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data := args[0]
|
||||||
|
go p.handleWebSocketMessage(conn, data)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||||
|
log.Debug("WebSocket closed by JavaScript")
|
||||||
|
conn.cancel()
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||||
|
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
length := data.Get("length").Int()
|
||||||
|
bytes := make([]byte, length)
|
||||||
|
js.CopyBytesToGo(bytes, data)
|
||||||
|
|
||||||
|
if conn.rdpConn != nil || conn.tlsConn != nil {
|
||||||
|
p.forwardToRDP(conn, bytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var pdu RDCleanPathPDU
|
||||||
|
_, err := asn1.Unmarshal(bytes, &pdu)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
|
||||||
|
n := len(bytes)
|
||||||
|
if n > 20 {
|
||||||
|
n = 20
|
||||||
|
}
|
||||||
|
log.Warnf("First %d bytes: %x", n, bytes[:n])
|
||||||
|
|
||||||
|
if len(bytes) > 0 && bytes[0] == 0x03 {
|
||||||
|
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
|
||||||
|
go p.handleDirectRDP(conn, bytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.processRDCleanPathPDU(conn, pdu)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
|
||||||
|
var writer io.Writer
|
||||||
|
var connType string
|
||||||
|
|
||||||
|
if conn.tlsConn != nil {
|
||||||
|
writer = conn.tlsConn
|
||||||
|
connType = "TLS"
|
||||||
|
} else if conn.rdpConn != nil {
|
||||||
|
writer = conn.rdpConn
|
||||||
|
connType = "TCP"
|
||||||
|
} else {
|
||||||
|
log.Error("No RDP connection available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := writer.Write(bytes); err != nil {
|
||||||
|
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
|
||||||
|
defer p.cleanupConnection(conn)
|
||||||
|
|
||||||
|
destination := conn.destination
|
||||||
|
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||||
|
|
||||||
|
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.rdpConn = rdpConn
|
||||||
|
|
||||||
|
_, err = rdpConn.Write(firstPacket)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to write first packet: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, 1024)
|
||||||
|
n, err := rdpConn.Read(response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to read X.224 response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.sendToWebSocket(conn, response[:n])
|
||||||
|
|
||||||
|
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||||
|
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||||
|
log.Debugf("Cleaning up connection %s", conn.id)
|
||||||
|
conn.cancel()
|
||||||
|
if conn.tlsConn != nil {
|
||||||
|
log.Debug("Closing TLS connection")
|
||||||
|
if err := conn.tlsConn.Close(); err != nil {
|
||||||
|
log.Debugf("Error closing TLS connection: %v", err)
|
||||||
|
}
|
||||||
|
conn.tlsConn = nil
|
||||||
|
}
|
||||||
|
if conn.rdpConn != nil {
|
||||||
|
log.Debug("Closing TCP connection")
|
||||||
|
if err := conn.rdpConn.Close(); err != nil {
|
||||||
|
log.Debugf("Error closing TCP connection: %v", err)
|
||||||
|
}
|
||||||
|
conn.rdpConn = nil
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
delete(p.activeConnections, conn.id)
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||||
|
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||||
|
js.CopyBytesToJS(uint8Array, data)
|
||||||
|
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||||
|
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||||
|
js.CopyBytesToJS(uint8Array, data)
|
||||||
|
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||||
|
}
|
||||||
|
}
|
||||||
249
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
249
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package rdp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/asn1"
|
||||||
|
"io"
|
||||||
|
"syscall/js"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
|
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||||
|
|
||||||
|
if pdu.Version != RDCleanPathVersion {
|
||||||
|
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
destination := conn.destination
|
||||||
|
if pdu.Destination != "" {
|
||||||
|
destination = pdu.Destination
|
||||||
|
}
|
||||||
|
|
||||||
|
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||||
|
p.sendRDCleanPathError(conn, "Connection failed")
|
||||||
|
p.cleanupConnection(conn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.rdpConn = rdpConn
|
||||||
|
|
||||||
|
// RDP always starts with X.224 negotiation, then determines if TLS is needed
|
||||||
|
// Modern RDP (since Windows Vista/2008) typically requires TLS
|
||||||
|
// The X.224 Connection Confirm response will indicate if TLS is required
|
||||||
|
// For now, we'll attempt TLS for all connections as it's the modern default
|
||||||
|
p.setupTLSConnection(conn, pdu)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
|
var x224Response []byte
|
||||||
|
if len(pdu.X224ConnectionPDU) > 0 {
|
||||||
|
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||||
|
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, 1024)
|
||||||
|
n, err := conn.rdpConn.Read(response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to read X.224 response: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
x224Response = response[:n]
|
||||||
|
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||||
|
|
||||||
|
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||||
|
conn.tlsConn = tlsConn
|
||||||
|
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
log.Errorf("TLS handshake failed: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("TLS handshake successful")
|
||||||
|
|
||||||
|
// Certificate validation happens during handshake via VerifyConnection callback
|
||||||
|
var certChain [][]byte
|
||||||
|
connState := tlsConn.ConnectionState()
|
||||||
|
if len(connState.PeerCertificates) > 0 {
|
||||||
|
for _, cert := range connState.PeerCertificates {
|
||||||
|
certChain = append(certChain, cert.Raw)
|
||||||
|
}
|
||||||
|
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
|
||||||
|
}
|
||||||
|
|
||||||
|
responsePDU := RDCleanPathPDU{
|
||||||
|
Version: RDCleanPathVersion,
|
||||||
|
ServerAddr: conn.destination,
|
||||||
|
ServerCertChain: certChain,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(x224Response) > 0 {
|
||||||
|
responsePDU.X224ConnectionPDU = x224Response
|
||||||
|
}
|
||||||
|
|
||||||
|
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||||
|
|
||||||
|
log.Debug("Starting TLS forwarding")
|
||||||
|
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
|
||||||
|
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
|
||||||
|
|
||||||
|
<-conn.ctx.Done()
|
||||||
|
log.Debug("TLS connection context done, cleaning up")
|
||||||
|
p.cleanupConnection(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
|
if len(pdu.X224ConnectionPDU) > 0 {
|
||||||
|
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||||
|
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, 1024)
|
||||||
|
n, err := conn.rdpConn.Read(response)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to read X.224 response: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
responsePDU := RDCleanPathPDU{
|
||||||
|
Version: RDCleanPathVersion,
|
||||||
|
X224ConnectionPDU: response[:n],
|
||||||
|
ServerAddr: conn.destination,
|
||||||
|
}
|
||||||
|
|
||||||
|
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||||
|
} else {
|
||||||
|
responsePDU := RDCleanPathPDU{
|
||||||
|
Version: RDCleanPathVersion,
|
||||||
|
ServerAddr: conn.destination,
|
||||||
|
}
|
||||||
|
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||||
|
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||||
|
|
||||||
|
<-conn.ctx.Done()
|
||||||
|
log.Debug("TCP connection context done, cleaning up")
|
||||||
|
p.cleanupConnection(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
|
data, err := asn1.Marshal(pdu)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
|
||||||
|
p.sendToWebSocket(conn, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||||
|
pdu := RDCleanPathPDU{
|
||||||
|
Version: RDCleanPathVersion,
|
||||||
|
Error: []byte(errorMsg),
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := asn1.Marshal(pdu)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.sendToWebSocket(conn, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||||
|
msgChan := make(chan []byte)
|
||||||
|
errChan := make(chan error)
|
||||||
|
|
||||||
|
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||||
|
if len(args) < 1 {
|
||||||
|
errChan <- io.EOF
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data := args[0]
|
||||||
|
if data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||||
|
length := data.Get("length").Int()
|
||||||
|
bytes := make([]byte, length)
|
||||||
|
js.CopyBytesToGo(bytes, data)
|
||||||
|
msgChan <- bytes
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
defer handler.Release()
|
||||||
|
|
||||||
|
conn.wsHandlers.Set("onceGoMessage", handler)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-msgChan:
|
||||||
|
return msg, nil
|
||||||
|
case err := <-errChan:
|
||||||
|
return nil, err
|
||||||
|
case <-conn.ctx.Done():
|
||||||
|
return nil, conn.ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
|
||||||
|
for {
|
||||||
|
if conn.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := p.readWebSocketMessage(conn)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
log.Errorf("Failed to read from WebSocket: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
|
||||||
|
buffer := make([]byte, 32*1024)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if conn.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
log.Errorf("Failed to read from %s: %v", connType, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 0 {
|
||||||
|
p.sendToWebSocket(conn, buffer[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
211
client/wasm/internal/ssh/client.go
Normal file
211
client/wasm/internal/ssh/client.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
sshDialTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func closeWithLog(c io.Closer, resource string) {
|
||||||
|
if c != nil {
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
logrus.Debugf("Failed to close %s: %v", resource, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
nbClient *netbird.Client
|
||||||
|
sshClient *ssh.Client
|
||||||
|
session *ssh.Session
|
||||||
|
stdin io.WriteCloser
|
||||||
|
stdout io.Reader
|
||||||
|
stderr io.Reader
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new SSH client
|
||||||
|
func NewClient(nbClient *netbird.Client) *Client {
|
||||||
|
return &Client{
|
||||||
|
nbClient: nbClient,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect establishes an SSH connection through NetBird network
|
||||||
|
func (c *Client) Connect(host string, port int, username string) error {
|
||||||
|
addr := fmt.Sprintf("%s:%d", host, port)
|
||||||
|
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||||
|
|
||||||
|
var authMethods []ssh.AuthMethod
|
||||||
|
|
||||||
|
nbConfig, err := c.nbClient.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get NetBird config: %w", err)
|
||||||
|
}
|
||||||
|
if nbConfig.SSHKey == "" {
|
||||||
|
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKey := signer.PublicKey()
|
||||||
|
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
|
||||||
|
|
||||||
|
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||||
|
|
||||||
|
config := &ssh.ClientConfig{
|
||||||
|
User: username,
|
||||||
|
Auth: authMethods,
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: sshDialTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dial %s: %w", addr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||||
|
if err != nil {
|
||||||
|
closeWithLog(conn, "connection after handshake error")
|
||||||
|
return fmt.Errorf("SSH handshake: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
|
||||||
|
logrus.Infof("SSH: Connected to %s", addr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartSession starts an SSH session with PTY
|
||||||
|
func (c *Client) StartSession(cols, rows int) error {
|
||||||
|
if c.sshClient == nil {
|
||||||
|
return fmt.Errorf("SSH client not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := c.sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.session = session
|
||||||
|
|
||||||
|
modes := ssh.TerminalModes{
|
||||||
|
ssh.ECHO: 1,
|
||||||
|
ssh.TTY_OP_ISPEED: 14400,
|
||||||
|
ssh.TTY_OP_OSPEED: 14400,
|
||||||
|
ssh.VINTR: 3,
|
||||||
|
ssh.VQUIT: 28,
|
||||||
|
ssh.VERASE: 127,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||||
|
closeWithLog(session, "session after PTY error")
|
||||||
|
return fmt.Errorf("PTY request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.stdin, err = session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
closeWithLog(session, "session after stdin error")
|
||||||
|
return fmt.Errorf("get stdin: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.stdout, err = session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
closeWithLog(session, "session after stdout error")
|
||||||
|
return fmt.Errorf("get stdout: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.stderr, err = session.StderrPipe()
|
||||||
|
if err != nil {
|
||||||
|
closeWithLog(session, "session after stderr error")
|
||||||
|
return fmt.Errorf("get stderr: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
closeWithLog(session, "session after shell error")
|
||||||
|
return fmt.Errorf("start shell: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Info("SSH: Session started with PTY")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write sends data to the SSH session
|
||||||
|
func (c *Client) Write(data []byte) (int, error) {
|
||||||
|
c.mu.RLock()
|
||||||
|
stdin := c.stdin
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if stdin == nil {
|
||||||
|
return 0, fmt.Errorf("SSH session not started")
|
||||||
|
}
|
||||||
|
return stdin.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads data from the SSH session
|
||||||
|
func (c *Client) Read(buffer []byte) (int, error) {
|
||||||
|
c.mu.RLock()
|
||||||
|
stdout := c.stdout
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if stdout == nil {
|
||||||
|
return 0, fmt.Errorf("SSH session not started")
|
||||||
|
}
|
||||||
|
return stdout.Read(buffer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize updates the terminal size
|
||||||
|
func (c *Client) Resize(cols, rows int) error {
|
||||||
|
c.mu.RLock()
|
||||||
|
session := c.session
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if session == nil {
|
||||||
|
return fmt.Errorf("SSH session not started")
|
||||||
|
}
|
||||||
|
return session.WindowChange(rows, cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the SSH connection
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.session != nil {
|
||||||
|
closeWithLog(c.session, "SSH session")
|
||||||
|
c.session = nil
|
||||||
|
}
|
||||||
|
if c.stdin != nil {
|
||||||
|
closeWithLog(c.stdin, "stdin")
|
||||||
|
c.stdin = nil
|
||||||
|
}
|
||||||
|
c.stdout = nil
|
||||||
|
c.stderr = nil
|
||||||
|
|
||||||
|
if c.sshClient != nil {
|
||||||
|
err := c.sshClient.Close()
|
||||||
|
c.sshClient = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
76
client/wasm/internal/ssh/handlers.go
Normal file
76
client/wasm/internal/ssh/handlers.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"syscall/js"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateJSInterface creates a JavaScript interface for the SSH client
|
||||||
|
func CreateJSInterface(client *Client) js.Value {
|
||||||
|
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||||
|
|
||||||
|
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 1 {
|
||||||
|
return js.ValueOf(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := args[0]
|
||||||
|
var bytes []byte
|
||||||
|
|
||||||
|
if data.Type() == js.TypeString {
|
||||||
|
bytes = []byte(data.String())
|
||||||
|
} else {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||||
|
length := uint8Array.Get("length").Int()
|
||||||
|
bytes = make([]byte, length)
|
||||||
|
js.CopyBytesToGo(bytes, uint8Array)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := client.Write(bytes)
|
||||||
|
return js.ValueOf(err == nil)
|
||||||
|
}))
|
||||||
|
|
||||||
|
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return js.ValueOf(false)
|
||||||
|
}
|
||||||
|
cols := args[0].Int()
|
||||||
|
rows := args[1].Int()
|
||||||
|
err := client.Resize(cols, rows)
|
||||||
|
return js.ValueOf(err == nil)
|
||||||
|
}))
|
||||||
|
|
||||||
|
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
client.Close()
|
||||||
|
return js.Undefined()
|
||||||
|
}))
|
||||||
|
|
||||||
|
go readLoop(client, jsInterface)
|
||||||
|
|
||||||
|
return jsInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLoop(client *Client, jsInterface js.Value) {
|
||||||
|
buffer := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
n, err := client.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
logrus.Debugf("SSH read error: %v", err)
|
||||||
|
}
|
||||||
|
if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() {
|
||||||
|
onclose.Invoke()
|
||||||
|
}
|
||||||
|
client.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() {
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(n)
|
||||||
|
js.CopyBytesToJS(uint8Array, buffer[:n])
|
||||||
|
ondata.Invoke(uint8Array)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
48
client/wasm/internal/ssh/key.go
Normal file
48
client/wasm/internal/ssh/key.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
|
||||||
|
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
|
||||||
|
keyStr := string(keyPEM)
|
||||||
|
if !strings.Contains(keyStr, "-----BEGIN") {
|
||||||
|
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ssh.ParsePrivateKey(keyPEM)
|
||||||
|
if err == nil {
|
||||||
|
return signer, nil
|
||||||
|
}
|
||||||
|
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
|
||||||
|
|
||||||
|
block, _ := pem.Decode(keyPEM)
|
||||||
|
if block == nil {
|
||||||
|
keyPreview := string(keyPEM)
|
||||||
|
if len(keyPreview) > 100 {
|
||||||
|
keyPreview = keyPreview[:100]
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
|
||||||
|
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||||
|
return ssh.NewSignerFromKey(rsaKey)
|
||||||
|
}
|
||||||
|
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||||
|
return ssh.NewSignerFromKey(ecKey)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.NewSignerFromKey(key)
|
||||||
|
}
|
||||||
@@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
|||||||
return nil, fmt.Errorf("parsing url: %w", err)
|
return nil, fmt.Errorf("parsing url: %w", err)
|
||||||
}
|
}
|
||||||
var opts []grpc.DialOption
|
var opts []grpc.DialOption
|
||||||
if parsedURL.Scheme == "https" {
|
tlsEnabled := parsedURL.Scheme == "https"
|
||||||
|
if tlsEnabled {
|
||||||
certPool, err := x509.SystemCertPool()
|
certPool, err := x509.SystemCertPool()
|
||||||
if err != nil || certPool == nil {
|
if err != nil || certPool == nil {
|
||||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||||
@@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
|||||||
}
|
}
|
||||||
|
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
nbgrpc.WithCustomDialer(),
|
nbgrpc.WithCustomDialer(tlsEnabled),
|
||||||
grpc.WithIdleTimeout(interval*2),
|
grpc.WithIdleTimeout(interval*2),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -38,7 +38,7 @@ require (
|
|||||||
github.com/c-robinson/iplib v1.0.3
|
github.com/c-robinson/iplib v1.0.3
|
||||||
github.com/caddyserver/certmagic v0.21.3
|
github.com/caddyserver/certmagic v0.21.3
|
||||||
github.com/cilium/ebpf v0.15.0
|
github.com/cilium/ebpf v0.15.0
|
||||||
github.com/coder/websocket v1.8.12
|
github.com/coder/websocket v1.8.13
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/creack/pty v1.1.18
|
github.com/creack/pty v1.1.18
|
||||||
github.com/eko/gocache/lib/v4 v4.2.0
|
github.com/eko/gocache/lib/v4 v4.2.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
|
|||||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||||
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||||
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
|
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
|
||||||
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
|
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
|
||||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
|
|||||||
@@ -73,7 +73,12 @@ func (l *Listener) Shutdown(ctx context.Context) error {
|
|||||||
|
|
||||||
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
|
||||||
connRemoteAddr := remoteAddr(r)
|
connRemoteAddr := remoteAddr(r)
|
||||||
wsConn, err := websocket.Accept(w, r, nil)
|
|
||||||
|
acceptOptions := &websocket.AcceptOptions{
|
||||||
|
OriginPatterns: []string{"*"},
|
||||||
|
}
|
||||||
|
|
||||||
|
wsConn, err := websocket.Accept(w, r, acceptOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
|
log.Errorf("failed to accept ws connection from %s: %s", connRemoteAddr, err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -223,10 +223,10 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
|||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
return nil, fmt.Errorf("relay connection is not established")
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
}
|
}
|
||||||
_, ok := c.conns[peerID]
|
existingContainer, ok := c.conns[peerID]
|
||||||
if ok {
|
if ok {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
return nil, ErrConnAlreadyExists
|
return existingContainer.conn, nil
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
@@ -235,7 +235,6 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
|
||||||
msgChannel := make(chan Msg, 100)
|
msgChannel := make(chan Msg, 100)
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -249,11 +248,11 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
|||||||
c.muInstanceURL.Unlock()
|
c.muInstanceURL.Unlock()
|
||||||
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||||
|
|
||||||
_, ok = c.conns[peerID]
|
existingContainer, ok = c.conns[peerID]
|
||||||
if ok {
|
if ok {
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return nil, ErrConnAlreadyExists
|
return existingContainer.conn, nil
|
||||||
}
|
}
|
||||||
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@@ -377,7 +376,6 @@ func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internal
|
|||||||
buf := *bufPtr
|
buf := *bufPtr
|
||||||
n, errExit = relayConn.Read(buf)
|
n, errExit = relayConn.Read(buf)
|
||||||
if errExit != nil {
|
if errExit != nil {
|
||||||
c.log.Infof("start to Relay read loop exit")
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||||
c.log.Errorf("failed to read message from relay server: %s", errExit)
|
c.log.Errorf("failed to read message from relay server: %s", errExit)
|
||||||
@@ -468,12 +466,24 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
|||||||
c.bufPool.Put(bufPtr)
|
c.bufPool.Put(bufPtr)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
container, ok := c.conns[*peerID]
|
container, ok := c.conns[*peerID]
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.log.Errorf("peer not found: %s", peerID.String())
|
// Try to create a connection for this peer to handle incoming messages
|
||||||
c.bufPool.Put(bufPtr)
|
msgChannel := make(chan Msg, 100)
|
||||||
return true
|
c.muInstanceURL.Lock()
|
||||||
|
instanceURL := c.instanceURL
|
||||||
|
c.muInstanceURL.Unlock()
|
||||||
|
conn := NewConn(c, *peerID, msgChannel, instanceURL)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
// Check again if connection was created while we were creating it
|
||||||
|
if _, exists := c.conns[*peerID]; !exists {
|
||||||
|
c.conns[*peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||||
|
}
|
||||||
|
container = c.conns[*peerID]
|
||||||
|
c.mu.Unlock()
|
||||||
}
|
}
|
||||||
msg := Msg{
|
msg := Msg{
|
||||||
bufPool: c.bufPool,
|
bufPool: c.bufPool,
|
||||||
|
|||||||
@@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||||
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) RemoteAddr() net.Addr {
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
|||||||
11
shared/relay/client/dialer/ws/dialopts_generic.go
Normal file
11
shared/relay/client/dialer/ws/dialopts_generic.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import "github.com/coder/websocket"
|
||||||
|
|
||||||
|
func createDialOptions() *websocket.DialOptions {
|
||||||
|
return &websocket.DialOptions{
|
||||||
|
HTTPClient: httpClientNbDialer(),
|
||||||
|
}
|
||||||
|
}
|
||||||
10
shared/relay/client/dialer/ws/dialopts_js.go
Normal file
10
shared/relay/client/dialer/ws/dialopts_js.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package ws
|
||||||
|
|
||||||
|
import "github.com/coder/websocket"
|
||||||
|
|
||||||
|
func createDialOptions() *websocket.DialOptions {
|
||||||
|
// WASM version doesn't support HTTPClient
|
||||||
|
return &websocket.DialOptions{}
|
||||||
|
}
|
||||||
@@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := &websocket.DialOptions{
|
opts := createDialOptions()
|
||||||
HTTPClient: httpClientNbDialer(),
|
|
||||||
}
|
|
||||||
|
|
||||||
parsedURL, err := url.Parse(wsURL)
|
parsedURL, err := url.Parse(wsURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -4,15 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os/user"
|
|
||||||
"runtime"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -21,35 +14,9 @@ import (
|
|||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
currentUser, err := user.Current()
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
|
||||||
if currentUser.Uid != "0" {
|
|
||||||
log.Debug("Not running as root, using standard dialer")
|
|
||||||
dialer := &net.Dialer{}
|
|
||||||
return dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
b.MaxElapsedTime = 10 * time.Second
|
b.MaxElapsedTime = 10 * time.Second
|
||||||
@@ -57,6 +24,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateConnection creates a gRPC client connection with the appropriate transport options
|
||||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
if tlsEnabled {
|
if tlsEnabled {
|
||||||
@@ -78,7 +46,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
|||||||
connCtx,
|
connCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(),
|
WithCustomDialer(tlsEnabled),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
|
|||||||
44
util/grpc/dialer_generic.go
Normal file
44
util/grpc/dialer_generic.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
|
||||||
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||||
|
if currentUser.Uid != "0" {
|
||||||
|
log.Debug("Not running as root, using standard dialer")
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to dial: %s", err)
|
||||||
|
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
12
util/grpc/dialer_js.go
Normal file
12
util/grpc/dialer_js.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package grpc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util/wsproxy/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
|
||||||
|
func WithCustomDialer(tlsEnabled bool) grpc.DialOption {
|
||||||
|
return client.WithWebSocketDialer(tlsEnabled)
|
||||||
|
}
|
||||||
8
util/util_js.go
Normal file
8
util/util_js.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build js
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
// IsAdmin returns false for WASM as there's no admin concept in browser
|
||||||
|
func IsAdmin() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
171
util/wsproxy/client/dialer_js.go
Normal file
171
util/wsproxy/client/dialer_js.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"syscall/js"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util/wsproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
const dialTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// websocketConn wraps a JavaScript WebSocket to implement net.Conn
|
||||||
|
type websocketConn struct {
|
||||||
|
ws js.Value
|
||||||
|
remoteAddr string
|
||||||
|
messages chan []byte
|
||||||
|
readBuf []byte
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) Read(b []byte) (int, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if len(c.readBuf) > 0 {
|
||||||
|
n := copy(b, c.readBuf)
|
||||||
|
c.readBuf = c.readBuf[n:]
|
||||||
|
c.mu.Unlock()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-c.messages:
|
||||||
|
n := copy(b, data)
|
||||||
|
if n < len(data) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.readBuf = data[n:]
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return 0, c.ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) Write(b []byte) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return 0, c.ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(len(b))
|
||||||
|
js.CopyBytesToJS(uint8Array, b)
|
||||||
|
c.ws.Call("send", uint8Array)
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) Close() error {
|
||||||
|
c.cancel()
|
||||||
|
c.ws.Call("close")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) RemoteAddr() net.Addr {
|
||||||
|
return stringAddr(c.remoteAddr)
|
||||||
|
}
|
||||||
|
func (c *websocketConn) SetDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *websocketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stringAddr is a simple net.Addr that returns a string
|
||||||
|
type stringAddr string
|
||||||
|
|
||||||
|
func (s stringAddr) Network() string { return "tcp" }
|
||||||
|
func (s stringAddr) String() string { return string(s) }
|
||||||
|
|
||||||
|
// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments.
|
||||||
|
func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption {
|
||||||
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
scheme := "wss"
|
||||||
|
if !tlsEnabled {
|
||||||
|
scheme = "ws"
|
||||||
|
}
|
||||||
|
wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath)
|
||||||
|
|
||||||
|
ws := js.Global().Get("WebSocket").New(wsURL)
|
||||||
|
|
||||||
|
connCtx, connCancel := context.WithCancel(context.Background())
|
||||||
|
conn := &websocketConn{
|
||||||
|
ws: ws,
|
||||||
|
remoteAddr: addr,
|
||||||
|
messages: make(chan []byte, 100),
|
||||||
|
ctx: connCtx,
|
||||||
|
cancel: connCancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
ws.Set("binaryType", "arraybuffer")
|
||||||
|
|
||||||
|
openCh := make(chan struct{})
|
||||||
|
errorCh := make(chan error, 1)
|
||||||
|
|
||||||
|
ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
close(openCh)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
select {
|
||||||
|
case errorCh <- wsproxy.ErrConnectionFailed:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
event := args[0]
|
||||||
|
data := event.Get("data")
|
||||||
|
|
||||||
|
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||||
|
length := uint8Array.Get("length").Int()
|
||||||
|
bytes := make([]byte, length)
|
||||||
|
js.CopyBytesToGo(bytes, uint8Array)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case conn.messages <- bytes:
|
||||||
|
default:
|
||||||
|
log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
conn.cancel()
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-openCh:
|
||||||
|
return conn, nil
|
||||||
|
case err := <-errorCh:
|
||||||
|
return nil, err
|
||||||
|
case <-ctx.Done():
|
||||||
|
ws.Call("close")
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(dialTimeout):
|
||||||
|
ws.Call("close")
|
||||||
|
return nil, wsproxy.ErrConnectionTimeout
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
13
util/wsproxy/constants.go
Normal file
13
util/wsproxy/constants.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package wsproxy
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// ProxyPath is the standard path where the WebSocket proxy is mounted on servers.
|
||||||
|
const ProxyPath = "/ws-proxy"
|
||||||
|
|
||||||
|
// Common errors
|
||||||
|
var (
|
||||||
|
ErrConnectionTimeout = errors.New("WebSocket connection timeout")
|
||||||
|
ErrConnectionFailed = errors.New("WebSocket connection failed")
|
||||||
|
ErrBackendUnavailable = errors.New("backend unavailable")
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user