From ebd78e01220bc63f20329058cba29a961f2c3c8c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 10 Apr 2026 20:51:04 +0200 Subject: [PATCH] [client] Update `RaceDial` to accept context for improved cancellation handling (#5849) --- shared/relay/client/client.go | 2 +- shared/relay/client/dialer/race_dialer.go | 4 ++-- shared/relay/client/dialer/race_dialer_test.go | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index ed1b63435..b10b05617 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -333,7 +333,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { dialers := c.getDialers() rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) - conn, err := rd.Dial() + conn, err := rd.Dial(ctx) if err != nil { return nil, err } diff --git a/shared/relay/client/dialer/race_dialer.go b/shared/relay/client/dialer/race_dialer.go index 0550fc63e..34359d17e 100644 --- a/shared/relay/client/dialer/race_dialer.go +++ b/shared/relay/client/dialer/race_dialer.go @@ -40,10 +40,10 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri } } -func (r *RaceDial) Dial() (net.Conn, error) { +func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) { connChan := make(chan dialResult, len(r.dialerFns)) winnerConn := make(chan net.Conn, 1) - abortCtx, abort := context.WithCancel(context.Background()) + abortCtx, abort := context.WithCancel(ctx) defer abort() for _, dfn := range r.dialerFns { diff --git a/shared/relay/client/dialer/race_dialer_test.go b/shared/relay/client/dialer/race_dialer_test.go index d216ec5e7..aa18df578 100644 --- a/shared/relay/client/dialer/race_dialer_test.go +++ b/shared/relay/client/dialer/race_dialer_test.go @@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) { serverURL := "test.server.com" rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error with empty dialers, got nil") } @@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) } @@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) { } rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err == nil { t.Errorf("Expected an error, got nil") } @@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) { } rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2) - conn, err := rd.Dial() + conn, err := rd.Dial(context.Background()) if err != nil { t.Errorf("Expected no error, got %v", err) }