[client,signal,management] Adjust browser client ws proxy paths (#4565)

This commit is contained in:
Viktor Liu
2025-10-02 00:10:47 +02:00
committed by GitHub
parent b5daec3b51
commit 4d7e59f199
10 changed files with 27 additions and 14 deletions

View File

@@ -25,8 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
// CreateConnection creates a gRPC client connection with the appropriate transport options // CreateConnection creates a gRPC client connection with the appropriate transport options.
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled { if tlsEnabled {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
@@ -49,7 +50,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
connCtx, connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(tlsEnabled), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(), grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func WithCustomDialer(tlsEnabled bool) grpc.DialOption { func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
currentUser, err := user.Current() currentUser, err := user.Current()

View File

@@ -7,6 +7,7 @@ import (
) )
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. // WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
func WithCustomDialer(tlsEnabled bool) grpc.DialOption { // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
return client.WithWebSocketDialer(tlsEnabled) func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
return client.WithWebSocketDialer(tlsEnabled, component)
} }

View File

@@ -23,6 +23,7 @@ import (
nbgrpc "github.com/netbirdio/netbird/client/grpc" nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/flow/proto"
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
"github.com/netbirdio/netbird/util/wsproxy"
) )
type GRPCClient struct { type GRPCClient struct {
@@ -54,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
} }
opts = append(opts, opts = append(opts,
nbgrpc.WithCustomDialer(tlsEnabled), nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent),
grpc.WithIdleTimeout(interval*2), grpc.WithIdleTimeout(interval*2),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,

View File

@@ -259,7 +259,7 @@ func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Hand
case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")):
gRPCHandler.ServeHTTP(writer, request) gRPCHandler.ServeHTTP(writer, request)
case request.URL.Path == wsproxy.ProxyPath: case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(writer, request) wsProxy.Handler().ServeHTTP(writer, request)
default: default:
httpHandler.ServeHTTP(writer, request) httpHandler.ServeHTTP(writer, request)

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/wsproxy"
) )
const ConnectTimeout = 10 * time.Second const ConnectTimeout = 10 * time.Second
@@ -52,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
operation := func() error { operation := func() error {
var err error var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
if err != nil { if err != nil {
log.Printf("createConnection error: %v", err) log.Printf("createConnection error: %v", err)
return err return err

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util/wsproxy"
) )
// ConnStateNotifier is a wrapper interface of the status recorder // ConnStateNotifier is a wrapper interface of the status recorder
@@ -57,7 +58,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
operation := func() error { operation := func() error {
var err error var err error
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
if err != nil { if err != nil {
log.Printf("createConnection error: %v", err) log.Printf("createConnection error: %v", err)
return err return err

View File

@@ -258,7 +258,7 @@ func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch { switch {
case r.URL.Path == wsproxy.ProxyPath: case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
wsProxy.Handler().ServeHTTP(w, r) wsProxy.Handler().ServeHTTP(w, r)
default: default:
grpcServer.ServeHTTP(w, r) grpcServer.ServeHTTP(w, r)

View File

@@ -96,13 +96,14 @@ func (s stringAddr) Network() string { return "tcp" }
func (s stringAddr) String() string { return string(s) } func (s stringAddr) String() string { return string(s) }
// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. // WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments.
func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func WithWebSocketDialer(tlsEnabled bool, component string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
scheme := "wss" scheme := "wss"
if !tlsEnabled { if !tlsEnabled {
scheme = "ws" scheme = "ws"
} }
wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) wsURL := fmt.Sprintf("%s://%s%s%s", scheme, addr, wsproxy.ProxyPath, component)
ws := js.Global().Get("WebSocket").New(wsURL) ws := js.Global().Get("WebSocket").New(wsURL)

View File

@@ -2,9 +2,16 @@ package wsproxy
import "errors" import "errors"
// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. // ProxyPath is the base path where the WebSocket proxy is mounted on servers.
const ProxyPath = "/ws-proxy" const ProxyPath = "/ws-proxy"
// Component paths that are appended to ProxyPath
const (
ManagementComponent = "/management"
SignalComponent = "/signal"
FlowComponent = "/flow"
)
// Common errors // Common errors
var ( var (
ErrConnectionTimeout = errors.New("WebSocket connection timeout") ErrConnectionTimeout = errors.New("WebSocket connection timeout")