mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client,signal,management] Add browser client support (#4415)
This commit is contained in:
100
client/wasm/internal/http/http.go
Normal file
100
client/wasm/internal/http/http.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build js
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
httpTimeout = 30 * time.Second
|
||||
maxResponseSize = 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
// performRequest executes an HTTP request through NetBird and returns the response and body
|
||||
func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) {
|
||||
httpClient := nbClient.NewHTTPClient()
|
||||
httpClient.Timeout = httpTimeout
|
||||
|
||||
req, err := http.NewRequest(method, url, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
log.Errorf("failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return resp, respBody, nil
|
||||
}
|
||||
|
||||
// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object
|
||||
func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) {
|
||||
url := request.Get("url").String()
|
||||
if url == "" {
|
||||
return js.Undefined(), fmt.Errorf("URL is required")
|
||||
}
|
||||
|
||||
method := "GET"
|
||||
if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() {
|
||||
method = strings.ToUpper(methodVal.String())
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() {
|
||||
requestBody = []byte(bodyVal.String())
|
||||
}
|
||||
|
||||
requestHeaders := make(map[string]string)
|
||||
if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject {
|
||||
headerKeys := js.Global().Get("Object").Call("keys", headersVal)
|
||||
for i := 0; i < headerKeys.Length(); i++ {
|
||||
key := headerKeys.Index(i).String()
|
||||
value := headersVal.Get(key).String()
|
||||
requestHeaders[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody)
|
||||
if err != nil {
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
result := js.Global().Get("Object").New()
|
||||
result.Set("status", resp.StatusCode)
|
||||
result.Set("statusText", resp.Status)
|
||||
result.Set("body", string(body))
|
||||
|
||||
headers := js.Global().Get("Object").New()
|
||||
for key, values := range resp.Header {
|
||||
if len(values) > 0 {
|
||||
headers.Set(strings.ToLower(key), values[0])
|
||||
}
|
||||
}
|
||||
result.Set("headers", headers)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
96
client/wasm/internal/rdp/cert_validation.go
Normal file
96
client/wasm/internal/rdp/cert_validation.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
|
||||
return false, fmt.Errorf("certificate validation handler not configured")
|
||||
}
|
||||
|
||||
certInfo := js.Global().Get("Object").New()
|
||||
certInfo.Set("ServerAddr", conn.destination)
|
||||
|
||||
certArray := js.Global().Get("Array").New()
|
||||
for i, certBytes := range certChain {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
|
||||
js.CopyBytesToJS(uint8Array, certBytes)
|
||||
certArray.SetIndex(i, uint8Array)
|
||||
}
|
||||
certInfo.Set("ServerCertChain", certArray)
|
||||
if len(certChain) > 0 {
|
||||
cert, err := x509.ParseCertificate(certChain[0])
|
||||
if err == nil {
|
||||
info := js.Global().Get("Object").New()
|
||||
info.Set("subject", cert.Subject.String())
|
||||
info.Set("issuer", cert.Issuer.String())
|
||||
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
|
||||
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
|
||||
info.Set("serialNumber", cert.SerialNumber.String())
|
||||
certInfo.Set("CertificateInfo", info)
|
||||
}
|
||||
}
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result {
|
||||
log.Info("Certificate accepted by user")
|
||||
} else {
|
||||
log.Info("Certificate rejected by user")
|
||||
}
|
||||
return result, nil
|
||||
case err := <-errorChan:
|
||||
return false, err
|
||||
case <-time.After(certValidationTimeout):
|
||||
return false, fmt.Errorf("certificate validation timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
|
||||
accepted, err := p.validateCertificateWithJS(conn, certChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !accepted {
|
||||
return fmt.Errorf("certificate rejected by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
271
client/wasm/internal/rdp/rdcleanpath.go
Normal file
271
client/wasm/internal/rdp/rdcleanpath.go
Normal file
@@ -0,0 +1,271 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type proxyConnection struct {
|
||||
id string
|
||||
destination string
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
func NewRDCleanPathProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *RDCleanPathProxy {
|
||||
return &RDCleanPathProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*proxyConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]string)
|
||||
}
|
||||
p.destinations[proxyID] = destination
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
|
||||
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
destination := p.destinations[proxyID]
|
||||
p.mu.Unlock()
|
||||
|
||||
if destination == "" {
|
||||
log.Errorf("No destination found for proxy ID: %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Don't defer cancel here - it will be called by cleanupConnection
|
||||
|
||||
conn := &proxyConnection{
|
||||
id: proxyID,
|
||||
destination: destination,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
|
||||
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
|
||||
if conn.rdpConn != nil || conn.tlsConn != nil {
|
||||
p.forwardToRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
|
||||
var pdu RDCleanPathPDU
|
||||
_, err := asn1.Unmarshal(bytes, &pdu)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
|
||||
n := len(bytes)
|
||||
if n > 20 {
|
||||
n = 20
|
||||
}
|
||||
log.Warnf("First %d bytes: %x", n, bytes[:n])
|
||||
|
||||
if len(bytes) > 0 && bytes[0] == 0x03 {
|
||||
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
|
||||
go p.handleDirectRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go p.processRDCleanPathPDU(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
|
||||
var writer io.Writer
|
||||
var connType string
|
||||
|
||||
if conn.tlsConn != nil {
|
||||
writer = conn.tlsConn
|
||||
connType = "TLS"
|
||||
} else if conn.rdpConn != nil {
|
||||
writer = conn.rdpConn
|
||||
connType = "TCP"
|
||||
} else {
|
||||
log.Error("No RDP connection available")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := writer.Write(bytes); err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
|
||||
defer p.cleanupConnection(conn)
|
||||
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, response[:n])
|
||||
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
251
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
251
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
@@ -0,0 +1,251 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
return
|
||||
}
|
||||
|
||||
destination := conn.destination
|
||||
if pdu.Destination != "" {
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
// RDP always starts with X.224 negotiation, then determines if TLS is needed
|
||||
// Modern RDP (since Windows Vista/2008) typically requires TLS
|
||||
// The X.224 Connection Confirm response will indicate if TLS is required
|
||||
// For now, we'll attempt TLS for all connections as it's the modern default
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("TLS handshake successful")
|
||||
|
||||
// Certificate validation happens during handshake via VerifyConnection callback
|
||||
var certChain [][]byte
|
||||
connState := tlsConn.ConnectionState()
|
||||
if len(connState.PeerCertificates) > 0 {
|
||||
for _, cert := range connState.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
ServerCertChain: certChain,
|
||||
}
|
||||
|
||||
if len(x224Response) > 0 {
|
||||
responsePDU.X224ConnectionPDU = x224Response
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
|
||||
log.Debug("Starting TLS forwarding")
|
||||
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
|
||||
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TLS connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
errChan <- io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
if data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
msgChan <- bytes
|
||||
}
|
||||
return nil
|
||||
})
|
||||
defer handler.Release()
|
||||
|
||||
conn.wsHandlers.Set("onceGoMessage", handler)
|
||||
|
||||
select {
|
||||
case msg := <-msgChan:
|
||||
return msg, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-conn.ctx.Done():
|
||||
return nil, conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := p.readWebSocketMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from WebSocket: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from %s: %v", connType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
213
client/wasm/internal/ssh/client.go
Normal file
213
client/wasm/internal/ssh/client.go
Normal file
@@ -0,0 +1,213 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
func closeWithLog(c io.Closer, resource string) {
|
||||
if c != nil {
|
||||
if err := c.Close(); err != nil {
|
||||
logrus.Debugf("Failed to close %s: %v", resource, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
nbClient *netbird.Client
|
||||
sshClient *ssh.Client
|
||||
session *ssh.Session
|
||||
stdin io.WriteCloser
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient creates a new SSH client
|
||||
func NewClient(nbClient *netbird.Client) *Client {
|
||||
return &Client{
|
||||
nbClient: nbClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes an SSH connection through NetBird network
|
||||
func (c *Client) Connect(host string, port int, username string) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||
|
||||
var authMethods []ssh.AuthMethod
|
||||
|
||||
nbConfig, err := c.nbClient.GetConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get NetBird config: %w", err)
|
||||
}
|
||||
if nbConfig.SSHKey == "" {
|
||||
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
|
||||
}
|
||||
|
||||
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||
}
|
||||
|
||||
pubKey := signer.PublicKey()
|
||||
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
|
||||
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
closeWithLog(conn, "connection after handshake error")
|
||||
return fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
|
||||
logrus.Infof("SSH: Connected to %s", addr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSession starts an SSH session with PTY
|
||||
func (c *Client) StartSession(cols, rows int) error {
|
||||
if c.sshClient == nil {
|
||||
return fmt.Errorf("SSH client not connected")
|
||||
}
|
||||
|
||||
session, err := c.sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.session = session
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.VINTR: 3,
|
||||
ssh.VQUIT: 28,
|
||||
ssh.VERASE: 127,
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeWithLog(session, "session after PTY error")
|
||||
return fmt.Errorf("PTY request: %w", err)
|
||||
}
|
||||
|
||||
c.stdin, err = session.StdinPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stdin error")
|
||||
return fmt.Errorf("get stdin: %w", err)
|
||||
}
|
||||
|
||||
c.stdout, err = session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stdout error")
|
||||
return fmt.Errorf("get stdout: %w", err)
|
||||
}
|
||||
|
||||
c.stderr, err = session.StderrPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stderr error")
|
||||
return fmt.Errorf("get stderr: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeWithLog(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("SSH: Session started with PTY")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session
|
||||
func (c *Client) Write(data []byte) (int, error) {
|
||||
c.mu.RLock()
|
||||
stdin := c.stdin
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stdin == nil {
|
||||
return 0, fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return stdin.Write(data)
|
||||
}
|
||||
|
||||
// Read reads data from the SSH session
|
||||
func (c *Client) Read(buffer []byte) (int, error) {
|
||||
c.mu.RLock()
|
||||
stdout := c.stdout
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stdout == nil {
|
||||
return 0, fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return stdout.Read(buffer)
|
||||
}
|
||||
|
||||
// Resize updates the terminal size
|
||||
func (c *Client) Resize(cols, rows int) error {
|
||||
c.mu.RLock()
|
||||
session := c.session
|
||||
c.mu.RUnlock()
|
||||
|
||||
if session == nil {
|
||||
return fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close closes the SSH connection
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.session != nil {
|
||||
closeWithLog(c.session, "SSH session")
|
||||
c.session = nil
|
||||
}
|
||||
if c.stdin != nil {
|
||||
closeWithLog(c.stdin, "stdin")
|
||||
c.stdin = nil
|
||||
}
|
||||
c.stdout = nil
|
||||
c.stderr = nil
|
||||
|
||||
if c.sshClient != nil {
|
||||
err := c.sshClient.Close()
|
||||
c.sshClient = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
78
client/wasm/internal/ssh/handlers.go
Normal file
78
client/wasm/internal/ssh/handlers.go
Normal file
@@ -0,0 +1,78 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CreateJSInterface creates a JavaScript interface for the SSH client
|
||||
func CreateJSInterface(client *Client) js.Value {
|
||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
var bytes []byte
|
||||
|
||||
if data.Type() == js.TypeString {
|
||||
bytes = []byte(data.String())
|
||||
} else {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||
length := uint8Array.Get("length").Int()
|
||||
bytes = make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, uint8Array)
|
||||
}
|
||||
|
||||
_, err := client.Write(bytes)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
|
||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
cols := args[0].Int()
|
||||
rows := args[1].Int()
|
||||
err := client.Resize(cols, rows)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
|
||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
client.Close()
|
||||
return js.Undefined()
|
||||
}))
|
||||
|
||||
go readLoop(client, jsInterface)
|
||||
|
||||
return jsInterface
|
||||
}
|
||||
|
||||
func readLoop(client *Client, jsInterface js.Value) {
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := client.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
logrus.Debugf("SSH read error: %v", err)
|
||||
}
|
||||
if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() {
|
||||
onclose.Invoke()
|
||||
}
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(n)
|
||||
js.CopyBytesToJS(uint8Array, buffer[:n])
|
||||
ondata.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
client/wasm/internal/ssh/key.go
Normal file
50
client/wasm/internal/ssh/key.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
|
||||
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
|
||||
keyStr := string(keyPEM)
|
||||
if !strings.Contains(keyStr, "-----BEGIN") {
|
||||
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyPEM)
|
||||
if err == nil {
|
||||
return signer, nil
|
||||
}
|
||||
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
|
||||
|
||||
block, _ := pem.Decode(keyPEM)
|
||||
if block == nil {
|
||||
keyPreview := string(keyPEM)
|
||||
if len(keyPreview) > 100 {
|
||||
keyPreview = keyPreview[:100]
|
||||
}
|
||||
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(rsaKey)
|
||||
}
|
||||
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(ecKey)
|
||||
}
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
Reference in New Issue
Block a user