diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 37b189e05..8fd6a7f01 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "net/url" - "strings" "github.com/coder/websocket" log "github.com/sirupsen/logrus" @@ -34,13 +33,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { opts := createDialOptions() - parsedURL, err := url.Parse(wsURL) - if err != nil { - return nil, err - } - parsedURL.Path = relay.WebSocketURLPath - - wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts) + wsConn, resp, err := websocket.Dial(ctx, wsURL, opts) if err != nil { if errors.Is(err, context.Canceled) { return nil, err @@ -56,12 +49,24 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return conn, nil } +// prepareURL rewrites a rel://host[:port] or rels://host[:port] address into a +// ws://host[:port]/relay or wss://host[:port]/relay URL, preserving any +// non-standard port from the input. func prepareURL(address string) (string, error) { - if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") { - return "", fmt.Errorf("unsupported scheme: %s", address) + parsed, err := url.Parse(address) + if err != nil { + return "", fmt.Errorf("parse relay address %q: %w", address, err) } - - return strings.Replace(address, "rel", "ws", 1), nil + switch parsed.Scheme { + case "rel": + parsed.Scheme = "ws" + case "rels": + parsed.Scheme = "wss" + default: + return "", fmt.Errorf("unsupported scheme: %s", parsed.Scheme) + } + parsed.Path = relay.WebSocketURLPath + return parsed.String(), nil } func httpClientNbDialer() *http.Client { diff --git a/shared/relay/client/dialer/ws/ws_test.go b/shared/relay/client/dialer/ws/ws_test.go new file mode 100644 index 000000000..67e3cd227 --- /dev/null +++ b/shared/relay/client/dialer/ws/ws_test.go @@ -0,0 +1,66 @@ +package ws + +import ( + "testing" +) + +func TestPrepareURL(t *testing.T) { + tests := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "rel scheme with non-standard port", + input: "rel://test-domain-2:45678", + want: "ws://test-domain-2:45678/relay", + }, + { + name: "rels scheme with non-standard port", + input: "rels://test-domain-2:45678", + want: "wss://test-domain-2:45678/relay", + }, + { + name: "rel scheme without port", + input: "rel://test-domain-2", + want: "ws://test-domain-2/relay", + }, + { + name: "rels scheme without port", + input: "rels://test-domain-2", + want: "wss://test-domain-2/relay", + }, + { + name: "rel scheme with IP and port", + input: "rel://1.2.3.4:45678", + want: "ws://1.2.3.4:45678/relay", + }, + { + name: "rel scheme with hostname starting with rel", + input: "rel://relay.example.com:45678", + want: "ws://relay.example.com:45678/relay", + }, + { + name: "unsupported scheme", + input: "http://test-domain-2:45678", + wantErr: true, + }, + { + name: "no scheme", + input: "test-domain-2:45678", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := prepareURL(tt.input) + if (err != nil) != tt.wantErr { + t.Fatalf("prepareURL(%q) err = %v, wantErr %v", tt.input, err, tt.wantErr) + } + if got != tt.want { + t.Errorf("prepareURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +}