From 6b08e89c7bb318f51c24b467a34fe05e13e17fcf Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 11 May 2026 16:59:33 +0900 Subject: [PATCH] [relay] Preserve non-standard port in WS dialer URL prep (#6061) --- shared/relay/client/dialer/ws/ws.go | 32 ++++++---- shared/relay/client/dialer/ws/ws_test.go | 76 ++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 shared/relay/client/dialer/ws/ws_test.go diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 301486514..8a13ba126 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "net/url" - "strings" "github.com/coder/websocket" log "github.com/sirupsen/logrus" @@ -35,13 +34,7 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, var underlying net.Conn opts := createDialOptions(serverName, &underlying) - parsedURL, err := url.Parse(wsURL) - if err != nil { - return nil, err - } - parsedURL.Path = relay.WebSocketURLPath - - wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts) + wsConn, resp, err := websocket.Dial(ctx, wsURL, opts) if err != nil { if errors.Is(err, context.Canceled) { return nil, err @@ -57,12 +50,27 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, return conn, nil } +// prepareURL rewrites a rel://host[:port] or rels://host[:port] address into a +// ws://host[:port]/relay or wss://host[:port]/relay URL, preserving any +// non-standard port from the input. func prepareURL(address string) (string, error) { - if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") { - return "", fmt.Errorf("unsupported scheme: %s", address) + parsed, err := url.Parse(address) + if err != nil { + return "", fmt.Errorf("parse relay address %q: %w", address, err) } - - return strings.Replace(address, "rel", "ws", 1), nil + switch parsed.Scheme { + case "rel": + parsed.Scheme = "ws" + case "rels": + parsed.Scheme = "wss" + default: + return "", fmt.Errorf("unsupported scheme: %s", parsed.Scheme) + } + if parsed.Host == "" { + return "", fmt.Errorf("missing host in relay address %q", address) + } + parsed.Path = relay.WebSocketURLPath + return parsed.String(), nil } // httpClientNbDialer builds the http client used by the websocket library. diff --git a/shared/relay/client/dialer/ws/ws_test.go b/shared/relay/client/dialer/ws/ws_test.go new file mode 100644 index 000000000..7357adbc0 --- /dev/null +++ b/shared/relay/client/dialer/ws/ws_test.go @@ -0,0 +1,76 @@ +package ws + +import ( + "testing" +) + +func TestPrepareURL(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "rel scheme with non-standard port", + input: "rel://test-domain-2:45678", + want: "ws://test-domain-2:45678/relay", + }, + { + name: "rels scheme with non-standard port", + input: "rels://test-domain-2:45678", + want: "wss://test-domain-2:45678/relay", + }, + { + name: "rel scheme without port", + input: "rel://test-domain-2", + want: "ws://test-domain-2/relay", + }, + { + name: "rels scheme without port", + input: "rels://test-domain-2", + want: "wss://test-domain-2/relay", + }, + { + name: "rel scheme with IP and port", + input: "rel://1.2.3.4:45678", + want: "ws://1.2.3.4:45678/relay", + }, + { + name: "rel scheme with hostname starting with rel", + input: "rel://relay.example.com:45678", + want: "ws://relay.example.com:45678/relay", + }, + { + name: "rel scheme with IPv6 and port", + input: "rel://[2001:db8::1]:45678", + want: "ws://[2001:db8::1]:45678/relay", + }, + { + name: "rels scheme with IPv6 loopback and port", + input: "rels://[::1]:45678", + want: "wss://[::1]:45678/relay", + }, + { + name: "unsupported scheme", + input: "http://test-domain-2:45678", + wantErr: true, + }, + { + name: "no scheme", + input: "test-domain-2:45678", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := prepareURL(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("prepareURL(%q) err = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if got != tt.want { + t.Errorf("prepareURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +}