mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[server, relay] Fix/relay race disconnection (#4174)
Avoid invalid disconnection notifications in case the closed race dials. In this PR resolve multiple race condition questions. Easier to understand the fix based on commit by commit. - Remove store dependency from notifier - Enforce the notification orders - Fix invalid disconnection notification - Ensure the order of the events on the consumer side
This commit is contained in:
@@ -292,7 +292,7 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect(ctx)
|
||||
if err != nil {
|
||||
if err = relayClient.Connect(ctx); err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func(_ string) {
|
||||
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Fatalf("timeout waiting for client to disconnect")
|
||||
log.Errorf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn(ctx, "bob")
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
const (
|
||||
DefaultConnectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
@@ -25,16 +25,18 @@ type dialResult struct {
|
||||
}
|
||||
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
return &RaceDial{
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
connectionTimeout: connectionTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
|
||||
@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
rd := NewRaceDial(logger, serverURL)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error with empty dialers, got nil")
|
||||
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
||||
protocolStr: proto,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
protocolStr: "proto2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
if conn.RemoteAddr().Network() != proto2 {
|
||||
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func TestRaceDialTimeout(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
connectionTimeout = 3 * time.Second
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer)
|
||||
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
|
||||
protocolStr: "protocol2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
protocolStr: proto2,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
const (
|
||||
// TODO: make it configurable, the manager should validate all configurable parameters
|
||||
reconnectingTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
|
||||
@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
// ManagerService is the interface for the relay manager.
|
||||
type ManagerService interface {
|
||||
Serve() error
|
||||
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
|
||||
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
|
||||
RelayInstanceAddress() (string, error)
|
||||
ServerURLs() []string
|
||||
HasRelayAddress() bool
|
||||
UpdateToken(token *relayAuth.Token) error
|
||||
}
|
||||
|
||||
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
||||
// and automatically reconnect to them in case disconnection.
|
||||
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
)
|
||||
|
||||
func TestEmptyURL(t *testing.T) {
|
||||
mgr := NewManager(context.Background(), nil, "alice")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
mgr := NewManager(ctx, nil, "alice")
|
||||
err := mgr.Serve()
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginAutoClose(t *testing.T) {
|
||||
func TestForeignAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
keepUnusedServerTime = 2 * time.Second
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
// Set up a disconnect listener to track when foreign server disconnects
|
||||
foreignServerURL := toURL(srvCfg2)[0]
|
||||
disconnected := make(chan struct{})
|
||||
onDisconnect := func() {
|
||||
select {
|
||||
case disconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
|
||||
t.Fatalf("should have failed to open connection to another peer")
|
||||
}
|
||||
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||
// Add the disconnect listener after the connection attempt
|
||||
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
|
||||
t.Logf("failed to add close listener (expected if connection failed): %s", err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to happen
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
time.Sleep(timeout)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
t.Log("foreign relay connection cleaned up successfully")
|
||||
case <-time.After(timeout):
|
||||
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
|
||||
srvCfg := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
@@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
if err := srv.Listen(srvCfg); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
Reference in New Issue
Block a user