Rewrite relay WS dialer URL prep using net/url to preserve non-standard ports

This commit is contained in:
Viktor Liu
2026-05-04 11:30:45 +02:00
parent c4b2da4c92
commit a409678fe5
2 changed files with 83 additions and 12 deletions

View File

@@ -9,7 +9,6 @@ import (
"net"
"net/http"
"net/url"
"strings"
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
@@ -34,13 +33,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
opts := createDialOptions()
parsedURL, err := url.Parse(wsURL)
if err != nil {
return nil, err
}
parsedURL.Path = relay.WebSocketURLPath
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
wsConn, resp, err := websocket.Dial(ctx, wsURL, opts)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
@@ -56,12 +49,24 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return conn, nil
}
// prepareURL rewrites a rel://host[:port] or rels://host[:port] address into a
// ws://host[:port]/relay or wss://host[:port]/relay URL, preserving any
// non-standard port from the input.
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
return "", fmt.Errorf("unsupported scheme: %s", address)
parsed, err := url.Parse(address)
if err != nil {
return "", fmt.Errorf("parse relay address %q: %w", address, err)
}
return strings.Replace(address, "rel", "ws", 1), nil
switch parsed.Scheme {
case "rel":
parsed.Scheme = "ws"
case "rels":
parsed.Scheme = "wss"
default:
return "", fmt.Errorf("unsupported scheme: %s", parsed.Scheme)
}
parsed.Path = relay.WebSocketURLPath
return parsed.String(), nil
}
func httpClientNbDialer() *http.Client {

View File

@@ -0,0 +1,66 @@
package ws
import (
"testing"
)
func TestPrepareURL(t *testing.T) {
tests := []struct {
name string
input string
want string
wantErr bool
}{
{
name: "rel scheme with non-standard port",
input: "rel://test-domain-2:45678",
want: "ws://test-domain-2:45678/relay",
},
{
name: "rels scheme with non-standard port",
input: "rels://test-domain-2:45678",
want: "wss://test-domain-2:45678/relay",
},
{
name: "rel scheme without port",
input: "rel://test-domain-2",
want: "ws://test-domain-2/relay",
},
{
name: "rels scheme without port",
input: "rels://test-domain-2",
want: "wss://test-domain-2/relay",
},
{
name: "rel scheme with IP and port",
input: "rel://1.2.3.4:45678",
want: "ws://1.2.3.4:45678/relay",
},
{
name: "rel scheme with hostname starting with rel",
input: "rel://relay.example.com:45678",
want: "ws://relay.example.com:45678/relay",
},
{
name: "unsupported scheme",
input: "http://test-domain-2:45678",
wantErr: true,
},
{
name: "no scheme",
input: "test-domain-2:45678",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := prepareURL(tt.input)
if (err != nil) != tt.wantErr {
t.Fatalf("prepareURL(%q) err = %v, wantErr %v", tt.input, err, tt.wantErr)
}
if got != tt.want {
t.Errorf("prepareURL(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}