Add relayServerIP field to signal for foreign-relay fallback dial

This commit is contained in:
Viktor Liu
2026-04-27 18:01:46 +02:00
parent d6f08e4840
commit e7bd62f58c
23 changed files with 786 additions and 122 deletions

View File

@@ -23,7 +23,7 @@ func (d Dialer) Protocol() string {
return Network
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
return nil, err
@@ -32,11 +32,14 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
// Get the base TLS config
tlsClientConfig := quictls.ClientQUICTLSConfig()
// Set ServerName to hostname if not an IP address
host, _, splitErr := net.SplitHostPort(quicURL)
if splitErr == nil && net.ParseIP(host) == nil {
// It's a hostname, not an IP - modify directly
tlsClientConfig.ServerName = host
switch {
case serverName != "" && net.ParseIP(serverName) == nil:
tlsClientConfig.ServerName = serverName
case serverName == "":
host, _, splitErr := net.SplitHostPort(quicURL)
if splitErr == nil && net.ParseIP(host) == nil {
tlsClientConfig.ServerName = host
}
}
quicConfig := &quic.Config{

View File

@@ -14,7 +14,9 @@ const (
)
type DialeFn interface {
Dial(ctx context.Context, address string) (net.Conn, error)
// Dial connects to address. serverName, when non-empty, overrides the TLS
// ServerName used for SNI/cert validation. Empty means derive from address.
Dial(ctx context.Context, address, serverName string) (net.Conn, error)
Protocol() string
}
@@ -27,6 +29,7 @@ type dialResult struct {
type RaceDial struct {
log *log.Entry
serverURL string
serverName string
dialerFns []DialeFn
connectionTimeout time.Duration
}
@@ -40,6 +43,16 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
}
}
// WithServerName sets a TLS SNI/cert validation override. Used when serverURL
// contains an IP literal but the cert is issued for a different hostname.
//
// Mutates the receiver and is not safe for concurrent reconfiguration; a
// RaceDial is intended to be constructed per dial and discarded.
func (r *RaceDial) WithServerName(serverName string) *RaceDial {
r.serverName = serverName
return r
}
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
@@ -64,7 +77,7 @@ func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dia
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(ctx, r.serverURL)
conn, err := dfn.Dial(ctx, r.serverURL, r.serverName)
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}

View File

@@ -28,7 +28,7 @@ type MockDialer struct {
protocolStr string
}
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (m *MockDialer) Dial(ctx context.Context, address, _ string) (net.Conn, error) {
return m.dialFunc(ctx, address)
}

View File

@@ -12,14 +12,24 @@ import (
type Conn struct {
ctx context.Context
*websocket.Conn
remoteAddr WebsocketAddr
remoteAddr net.Addr
}
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
// NewConn builds a relay ws.Conn. underlying is the raw TCP/TLS conn captured
// from the http transport's DialContext; when set, RemoteAddr returns its
// peer address (an IP literal). When nil (e.g. wasm), RemoteAddr falls back
// to the dial-time URL.
func NewConn(wsConn *websocket.Conn, serverAddress string, underlying net.Conn) net.Conn {
var addr net.Addr = WebsocketAddr{serverAddress}
if underlying != nil {
if ra := underlying.RemoteAddr(); ra != nil {
addr = ra
}
}
return &Conn{
ctx: context.Background(),
Conn: wsConn,
remoteAddr: WebsocketAddr{serverAddress},
remoteAddr: addr,
}
}

View File

@@ -2,10 +2,14 @@
package ws
import "github.com/coder/websocket"
import (
"net"
func createDialOptions() *websocket.DialOptions {
"github.com/coder/websocket"
)
func createDialOptions(serverName string, underlyingOut *net.Conn) *websocket.DialOptions {
return &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
HTTPClient: httpClientNbDialer(serverName, underlyingOut),
}
}

View File

@@ -2,9 +2,13 @@
package ws
import "github.com/coder/websocket"
import (
"net"
func createDialOptions() *websocket.DialOptions {
// WASM version doesn't support HTTPClient
"github.com/coder/websocket"
)
func createDialOptions(_ string, _ *net.Conn) *websocket.DialOptions {
// WASM version doesn't support HTTPClient or custom TLS config.
return &websocket.DialOptions{}
}

View File

@@ -26,13 +26,14 @@ func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
}
opts := createDialOptions()
var underlying net.Conn
opts := createDialOptions(serverName, &underlying)
parsedURL, err := url.Parse(wsURL)
if err != nil {
@@ -52,7 +53,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
_ = resp.Body.Close()
}
conn := NewConn(wsConn, address)
conn := NewConn(wsConn, address, underlying)
return conn, nil
}
@@ -64,7 +65,10 @@ func prepareURL(address string) (string, error) {
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
// httpClientNbDialer builds the http client used by the websocket library.
// underlyingOut, when non-nil, is populated with the raw conn from the
// transport's DialContext so the caller can read its RemoteAddr.
func httpClientNbDialer(serverName string, underlyingOut *net.Conn) *http.Client {
customDialer := nbnet.NewDialer()
certPool, err := x509.SystemCertPool()
@@ -75,10 +79,15 @@ func httpClientNbDialer() *http.Client {
customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return customDialer.DialContext(ctx, network, addr)
c, err := customDialer.DialContext(ctx, network, addr)
if err == nil && underlyingOut != nil {
*underlyingOut = c
}
return c, err
},
TLSClientConfig: &tls.Config{
RootCAs: certPool,
RootCAs: certPool,
ServerName: serverName,
},
}