diff --git a/relay/server/relay.go b/relay/server/relay.go
index d86684937..771eea4fd 100644
--- a/relay/server/relay.go
+++ b/relay/server/relay.go
@@ -132,6 +132,7 @@ func (r *Relay) Accept(conn net.Conn) {
 	storeTime := time.Now()
 	if isReconnection := r.store.AddPeer(peer); isReconnection {
 		r.metrics.RecordPeerReconnection()
+		r.notifier.PeerWentOffline(peer.ID())
 	}
 	r.notifier.PeerCameOnline(peer.ID())
 
diff --git a/shared/relay/client/client_test.go b/shared/relay/client/client_test.go
index 8fe5f04f4..0f9997340 100644
--- a/shared/relay/client/client_test.go
+++ b/shared/relay/client/client_test.go
@@ -11,11 +11,11 @@ import (
 	"go.opentelemetry.io/otel"
 
 	"github.com/netbirdio/netbird/client/iface"
+	"github.com/netbirdio/netbird/relay/server"
 	"github.com/netbirdio/netbird/shared/relay/auth/allow"
 	"github.com/netbirdio/netbird/shared/relay/auth/hmac"
+	"github.com/netbirdio/netbird/shared/relay/messages"
 	"github.com/netbirdio/netbird/util"
-
-	"github.com/netbirdio/netbird/relay/server"
 )
 
 var (
@@ -427,6 +427,121 @@ func TestBindReconnect(t *testing.T) {
 	}
 }
 
+// TestBindReconnectRace repeatedly closes and immediately reconnects the peer
+// "alice" while the peer "bob" holds a connection to it, and verifies that bob
+// can open a fresh connection to alice after every reconnect.
+//
+// The failure mode is timing dependent, so the scenario is repeated many times;
+// a clean run does not by itself prove the race is impossible.
+func TestBindReconnectRace(t *testing.T) {
+	ctx := context.Background()
+
+	srvCfg := server.ListenerConfig{Address: serverListenAddr}
+	srv, err := server.NewServer(serverCfg)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		err := srv.Listen(srvCfg)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	defer func() {
+		log.Infof("closing server")
+		err := srv.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	// wait for servers to start
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	clientBob := NewClient(serverURL, hmacTokenStore, "bob", iface.DefaultMTU)
+	err = clientBob.Connect(ctx)
+	if err != nil {
+		t.Fatalf("failed to connect to server: %s", err)
+	}
+	defer func() {
+		if err := clientBob.Close(); err != nil {
+			t.Errorf("failed to close bob: %s", err)
+		}
+	}()
+
+	// Run the reconnection scenario multiple times to expose the race
+	const iterations = 1000
+	failures := 0
+	// Same input every iteration, so hash it once, outside the lock.
+	aliceID := messages.HashID("alice")
+
+	for i := 0; i < iterations; i++ {
+		log.Infof("Iteration %d/%d", i+1, iterations)
+
+		// Alice connects
+		clientAlice := NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
+		err = clientAlice.Connect(ctx)
+		if err != nil {
+			t.Fatalf("iteration %d: failed to connect alice: %s", i, err)
+		}
+
+		// Bob opens connection to Alice
+		_, err = clientBob.OpenConn(ctx, "alice")
+		if err != nil {
+			// Best-effort: do not leave alice connected when aborting the test.
+			_ = clientAlice.Close()
+			t.Fatalf("iteration %d: failed to open conn from bob: %s", i, err)
+		}
+
+		// Close Alice immediately
+		err = clientAlice.Close()
+		if err != nil {
+			t.Errorf("iteration %d: failed to close alice: %s", i, err)
+		}
+
+		// Reconnect Alice immediately (this is where the race occurs)
+		clientAlice = NewClient(serverURL, hmacTokenStore, "alice", iface.DefaultMTU)
+		err = clientAlice.Connect(ctx)
+		if err != nil {
+			t.Fatalf("iteration %d: failed to reconnect alice: %s", i, err)
+		}
+
+		// Bob tries to open a new connection to the reconnected Alice
+		// Without the fix, this will sometimes fail with "connection already exists"
+		// because Bob still has the old connection in its map
+		_, err = clientBob.OpenConn(ctx, "alice")
+		if err != nil {
+			log.Errorf("iteration %d: RACE DETECTED - failed to open new conn after reconnect: %s", i, err)
+			failures++
+		}
+
+		// Clean up; report a failed close the same way as the first close above.
+		err = clientAlice.Close()
+		if err != nil {
+			t.Errorf("iteration %d: failed to close reconnected alice: %s", i, err)
+		}
+
+		// Close Bob's connection to Alice to prepare for next iteration
+		clientBob.mu.Lock()
+		if container, ok := clientBob.conns[aliceID]; ok {
+			container.close()
+			delete(clientBob.conns, aliceID)
+		}
+		clientBob.mu.Unlock()
+	}
+
+	if failures > 0 {
+		t.Errorf("Race condition detected in %d out of %d iterations (%.1f%%)",
+			failures, iterations, float64(failures)/float64(iterations)*100)
+		return
+	}
+	log.Infof("No race detected in %d iterations (fix is working or race didn't trigger)", iterations)
+}
+
 func TestCloseConn(t *testing.T) {
 	ctx := context.Background()
 