mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Update RaceDial to accept context for improved cancellation handling (#5849)
This commit is contained in:
@@ -333,7 +333,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
|||||||
dialers := c.getDialers()
|
dialers := c.getDialers()
|
||||||
|
|
||||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,10 +40,10 @@ func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RaceDial) Dial() (net.Conn, error) {
|
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||||
connChan := make(chan dialResult, len(r.dialerFns))
|
connChan := make(chan dialResult, len(r.dialerFns))
|
||||||
winnerConn := make(chan net.Conn, 1)
|
winnerConn := make(chan net.Conn, 1)
|
||||||
abortCtx, abort := context.WithCancel(context.Background())
|
abortCtx, abort := context.WithCancel(ctx)
|
||||||
defer abort()
|
defer abort()
|
||||||
|
|
||||||
for _, dfn := range r.dialerFns {
|
for _, dfn := range r.dialerFns {
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
|
|||||||
serverURL := "test.server.com"
|
serverURL := "test.server.com"
|
||||||
|
|
||||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error with empty dialers, got nil")
|
t.Errorf("Expected an error with empty dialers, got nil")
|
||||||
}
|
}
|
||||||
@@ -104,7 +104,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error, got %v", err)
|
t.Errorf("Expected no error, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -137,7 +137,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error, got %v", err)
|
t.Errorf("Expected no error, got %v", err)
|
||||||
}
|
}
|
||||||
@@ -160,7 +160,7 @@ func TestRaceDialTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
|
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error, got nil")
|
t.Errorf("Expected an error, got nil")
|
||||||
}
|
}
|
||||||
@@ -188,7 +188,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected an error, got nil")
|
t.Errorf("Expected an error, got nil")
|
||||||
}
|
}
|
||||||
@@ -230,7 +230,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||||
conn, err := rd.Dial()
|
conn, err := rd.Dial(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Expected no error, got %v", err)
|
t.Errorf("Expected no error, got %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user