mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client,signal,management] Add browser client support (#4415)
This commit is contained in:
171
util/wsproxy/client/dialer_js.go
Normal file
171
util/wsproxy/client/dialer_js.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const dialTimeout = 30 * time.Second
|
||||
|
||||
// websocketConn wraps a JavaScript WebSocket to implement net.Conn
|
||||
type websocketConn struct {
|
||||
ws js.Value
|
||||
remoteAddr string
|
||||
messages chan []byte
|
||||
readBuf []byte
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *websocketConn) Read(b []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
if len(c.readBuf) > 0 {
|
||||
n := copy(b, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
c.mu.Unlock()
|
||||
return n, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
select {
|
||||
case data := <-c.messages:
|
||||
n := copy(b, data)
|
||||
if n < len(data) {
|
||||
c.mu.Lock()
|
||||
c.readBuf = data[n:]
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return n, nil
|
||||
case <-c.ctx.Done():
|
||||
return 0, c.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *websocketConn) Write(b []byte) (int, error) {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return 0, c.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(b))
|
||||
js.CopyBytesToJS(uint8Array, b)
|
||||
c.ws.Call("send", uint8Array)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) Close() error {
|
||||
c.cancel()
|
||||
c.ws.Call("close")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) RemoteAddr() net.Addr {
|
||||
return stringAddr(c.remoteAddr)
|
||||
}
|
||||
func (c *websocketConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *websocketConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stringAddr is a simple net.Addr that returns a string
|
||||
type stringAddr string
|
||||
|
||||
func (s stringAddr) Network() string { return "tcp" }
|
||||
func (s stringAddr) String() string { return string(s) }
|
||||
|
||||
// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments.
|
||||
func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
scheme := "wss"
|
||||
if !tlsEnabled {
|
||||
scheme = "ws"
|
||||
}
|
||||
wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath)
|
||||
|
||||
ws := js.Global().Get("WebSocket").New(wsURL)
|
||||
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
conn := &websocketConn{
|
||||
ws: ws,
|
||||
remoteAddr: addr,
|
||||
messages: make(chan []byte, 100),
|
||||
ctx: connCtx,
|
||||
cancel: connCancel,
|
||||
}
|
||||
|
||||
ws.Set("binaryType", "arraybuffer")
|
||||
|
||||
openCh := make(chan struct{})
|
||||
errorCh := make(chan error, 1)
|
||||
|
||||
ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
close(openCh)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
select {
|
||||
case errorCh <- wsproxy.ErrConnectionFailed:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
event := args[0]
|
||||
data := event.Get("data")
|
||||
|
||||
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||
length := uint8Array.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, uint8Array)
|
||||
|
||||
select {
|
||||
case conn.messages <- bytes:
|
||||
default:
|
||||
log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr)
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case <-openCh:
|
||||
return conn, nil
|
||||
case err := <-errorCh:
|
||||
return nil, err
|
||||
case <-ctx.Done():
|
||||
ws.Call("close")
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(dialTimeout):
|
||||
ws.Call("close")
|
||||
return nil, wsproxy.ErrConnectionTimeout
|
||||
}
|
||||
})
|
||||
}
|
||||
13
util/wsproxy/constants.go
Normal file
13
util/wsproxy/constants.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package wsproxy
|
||||
|
||||
import "errors"
|
||||
|
||||
// ProxyPath is the standard path where the WebSocket proxy is mounted on servers.
|
||||
const ProxyPath = "/ws-proxy"
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrConnectionTimeout = errors.New("WebSocket connection timeout")
|
||||
ErrConnectionFailed = errors.New("WebSocket connection failed")
|
||||
ErrBackendUnavailable = errors.New("backend unavailable")
|
||||
)
|
||||
118
util/wsproxy/server/metrics.go
Normal file
118
util/wsproxy/server/metrics.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// MetricsRecorder defines the interface for recording proxy metrics
|
||||
type MetricsRecorder interface {
|
||||
// RecordConnection records a new connection
|
||||
RecordConnection(ctx context.Context)
|
||||
// RecordDisconnection records a connection closing
|
||||
RecordDisconnection(ctx context.Context)
|
||||
// RecordBytesTransferred records bytes transferred in a direction
|
||||
RecordBytesTransferred(ctx context.Context, direction string, bytes int64)
|
||||
// RecordError records an error
|
||||
RecordError(ctx context.Context, errorType string)
|
||||
}
|
||||
|
||||
// NoOpMetricsRecorder is a no-op implementation that does nothing
|
||||
type NoOpMetricsRecorder struct{}
|
||||
|
||||
func (n NoOpMetricsRecorder) RecordConnection(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
func (n NoOpMetricsRecorder) RecordDisconnection(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
func (n NoOpMetricsRecorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) {
|
||||
// no-op
|
||||
}
|
||||
func (n NoOpMetricsRecorder) RecordError(ctx context.Context, errorType string) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Recorder implements MetricsRecorder using OpenTelemetry
|
||||
type Recorder struct {
|
||||
activeConnections metric.Int64UpDownCounter
|
||||
bytesTransferred metric.Int64Counter
|
||||
errors metric.Int64Counter
|
||||
}
|
||||
|
||||
// NewMetricsRecorder creates a new OpenTelemetry-based metrics recorder
|
||||
func NewMetricsRecorder(meter metric.Meter) (*Recorder, error) {
|
||||
activeConnections, err := meter.Int64UpDownCounter(
|
||||
"wsproxy_active_connections",
|
||||
metric.WithDescription("Number of active WebSocket proxy connections"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bytesTransferred, err := meter.Int64Counter(
|
||||
"wsproxy_bytes_transferred_total",
|
||||
metric.WithDescription("Total bytes transferred through the proxy"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errors, err := meter.Int64Counter(
|
||||
"wsproxy_errors_total",
|
||||
metric.WithDescription("Total number of proxy errors"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Recorder{
|
||||
activeConnections: activeConnections,
|
||||
bytesTransferred: bytesTransferred,
|
||||
errors: errors,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *Recorder) RecordConnection(ctx context.Context) {
|
||||
o.activeConnections.Add(ctx, 1)
|
||||
}
|
||||
|
||||
func (o *Recorder) RecordDisconnection(ctx context.Context) {
|
||||
o.activeConnections.Add(ctx, -1)
|
||||
}
|
||||
|
||||
func (o *Recorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) {
|
||||
o.bytesTransferred.Add(ctx, bytes, metric.WithAttributes(
|
||||
attribute.String("direction", direction),
|
||||
))
|
||||
}
|
||||
|
||||
func (o *Recorder) RecordError(ctx context.Context, errorType string) {
|
||||
o.errors.Add(ctx, 1, metric.WithAttributes(
|
||||
attribute.String("error_type", errorType),
|
||||
))
|
||||
}
|
||||
|
||||
// Option defines functional options for the Proxy
|
||||
type Option func(*Config)
|
||||
|
||||
// WithMetrics sets a custom metrics recorder
|
||||
func WithMetrics(recorder MetricsRecorder) Option {
|
||||
return func(c *Config) {
|
||||
c.MetricsRecorder = recorder
|
||||
}
|
||||
}
|
||||
|
||||
// WithOTelMeter creates and sets an OpenTelemetry metrics recorder
|
||||
func WithOTelMeter(meter metric.Meter) Option {
|
||||
return func(c *Config) {
|
||||
if recorder, err := NewMetricsRecorder(meter); err == nil {
|
||||
c.MetricsRecorder = recorder
|
||||
} else {
|
||||
log.Warnf("Failed to create OTel metrics recorder: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
227
util/wsproxy/server/proxy.go
Normal file
227
util/wsproxy/server/proxy.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 10 * time.Second
|
||||
bufferSize = 32 * 1024
|
||||
)
|
||||
|
||||
// Config contains the configuration for the WebSocket proxy.
|
||||
type Config struct {
|
||||
LocalGRPCAddr netip.AddrPort
|
||||
Path string
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
// Proxy handles WebSocket to TCP proxying for gRPC connections.
|
||||
type Proxy struct {
|
||||
config Config
|
||||
metrics MetricsRecorder
|
||||
}
|
||||
|
||||
// New creates a new WebSocket proxy instance with optional configuration
|
||||
func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy {
|
||||
config := Config{
|
||||
LocalGRPCAddr: localGRPCAddr,
|
||||
Path: wsproxy.ProxyPath,
|
||||
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(&config)
|
||||
}
|
||||
|
||||
return &Proxy{
|
||||
config: config,
|
||||
metrics: config.MetricsRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
// Handler returns an http.Handler that proxies WebSocket connections to the local gRPC server.
|
||||
func (p *Proxy) Handler() http.Handler {
|
||||
return http.HandlerFunc(p.handleWebSocket)
|
||||
}
|
||||
|
||||
func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
p.metrics.RecordConnection(ctx)
|
||||
defer p.metrics.RecordDisconnection(ctx)
|
||||
|
||||
log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr)
|
||||
acceptOptions := &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
}
|
||||
|
||||
wsConn, err := websocket.Accept(w, r, acceptOptions)
|
||||
if err != nil {
|
||||
p.metrics.RecordError(ctx, "websocket_accept_failed")
|
||||
log.Errorf("WebSocket upgrade failed from %s: %v", r.RemoteAddr, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil {
|
||||
log.Debugf("Failed to close WebSocket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr)
|
||||
tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
|
||||
if err != nil {
|
||||
p.metrics.RecordError(ctx, "tcp_dial_failed")
|
||||
log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
|
||||
if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
|
||||
log.Debugf("Failed to close WebSocket after connection failure: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := tcpConn.Close(); err != nil {
|
||||
log.Debugf("Failed to close TCP connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr)
|
||||
|
||||
p.proxyData(ctx, wsConn, tcpConn)
|
||||
}
|
||||
|
||||
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) {
|
||||
proxyCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn)
|
||||
go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Tracef("Proxy data transfer completed, both goroutines terminated")
|
||||
case <-proxyCtx.Done():
|
||||
log.Tracef("Proxy data transfer cancelled, forcing connection closure")
|
||||
|
||||
if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
|
||||
log.Tracef("Error closing WebSocket during cancellation: %v", err)
|
||||
}
|
||||
if err := tcpConn.Close(); err != nil {
|
||||
log.Tracef("Error closing TCP connection during cancellation: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
log.Tracef("Goroutines terminated after forced connection closure")
|
||||
case <-time.After(2 * time.Second):
|
||||
log.Tracef("Goroutines did not terminate within timeout after connection closure")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
msgType, data, err := wsConn.Read(ctx)
|
||||
if err != nil {
|
||||
switch {
|
||||
case ctx.Err() != nil:
|
||||
log.Debugf("wsToTCP goroutine terminating due to context cancellation")
|
||||
case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
|
||||
log.Debugf("WebSocket closed normally")
|
||||
default:
|
||||
p.metrics.RecordError(ctx, "websocket_read_error")
|
||||
log.Errorf("WebSocket read error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if msgType != websocket.MessageBinary {
|
||||
log.Warnf("Unexpected WebSocket message type: %v", msgType)
|
||||
continue
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write")
|
||||
return
|
||||
}
|
||||
|
||||
if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Debugf("Failed to set TCP write deadline: %v", err)
|
||||
}
|
||||
|
||||
n, err := tcpConn.Write(data)
|
||||
if err != nil {
|
||||
p.metrics.RecordError(ctx, "tcp_write_error")
|
||||
log.Errorf("TCP write error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
for {
|
||||
if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Debugf("Failed to set TCP read deadline: %v", err)
|
||||
}
|
||||
n, err := tcpConn.Read(buf)
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
log.Tracef("tcpToWS goroutine terminating due to context cancellation")
|
||||
return
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err != io.EOF {
|
||||
log.Errorf("TCP read error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write")
|
||||
return
|
||||
}
|
||||
|
||||
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
|
||||
p.metrics.RecordError(ctx, "websocket_write_error")
|
||||
log.Errorf("WebSocket write error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user