mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Add early message buffer for relay client (#5282)
Add early message buffer to capture transport messages arriving before OpenConn completes, ensuring correct message ordering and no dropped messages.
This commit is contained in:
@@ -130,6 +130,7 @@ type Client struct {
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[messages.PeerID]*connContainer
|
||||
earlyMsgs *earlyMsgBuffer
|
||||
serviceIsRunning bool
|
||||
mu sync.Mutex // protect serviceIsRunning and conns
|
||||
readLoopMutex sync.Mutex
|
||||
@@ -165,6 +166,8 @@ func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string,
|
||||
conns: make(map[messages.PeerID]*connContainer),
|
||||
}
|
||||
|
||||
c.earlyMsgs = newEarlyMsgBuffer()
|
||||
|
||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
|
||||
return c
|
||||
}
|
||||
@@ -236,8 +239,14 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
|
||||
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||
container := newConnContainer(c.log, conn, msgChannel)
|
||||
c.conns[peerID] = container
|
||||
earlyMsg, hasEarly := c.earlyMsgs.pop(peerID)
|
||||
c.mu.Unlock()
|
||||
|
||||
if hasEarly {
|
||||
container.writeMsg(earlyMsg)
|
||||
c.log.Tracef("flushed buffered early message for peer: %s", peerID)
|
||||
}
|
||||
|
||||
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||
c.mu.Lock()
|
||||
@@ -466,10 +475,20 @@ func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppe
|
||||
return false
|
||||
}
|
||||
container, ok := c.conns[*peerID]
|
||||
earlyBuf := c.earlyMsgs
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", peerID.String())
|
||||
c.bufPool.Put(bufPtr)
|
||||
msg := Msg{
|
||||
bufPool: c.bufPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: payload,
|
||||
}
|
||||
if earlyBuf == nil || !earlyBuf.put(*peerID, msg) {
|
||||
c.log.Warnf("failed to buffer early message for peer: %s", peerID.String())
|
||||
c.bufPool.Put(bufPtr)
|
||||
} else {
|
||||
c.log.Debugf("buffered early transport message for peer: %s", peerID.String())
|
||||
}
|
||||
return true
|
||||
}
|
||||
msg := Msg{
|
||||
@@ -537,6 +556,9 @@ func (c *Client) closeAllConns() {
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[messages.PeerID]*connContainer)
|
||||
|
||||
c.earlyMsgs.close()
|
||||
c.earlyMsgs = newEarlyMsgBuffer()
|
||||
}
|
||||
|
||||
func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
|
||||
|
||||
175
shared/relay/client/early_msg_buffer.go
Normal file
175
shared/relay/client/early_msg_buffer.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
earlyMsgTTL = 5 * time.Second
|
||||
earlyMsgCapacity = 1000
|
||||
)
|
||||
|
||||
// earlyMsgBuffer buffers transport messages that arrive before the corresponding
|
||||
// OpenConn call. This happens during reconnection when the remote peer sends data
|
||||
// before the local side has set up the relay connection.
|
||||
//
|
||||
// It stores at most one message per peer (the first WireGuard handshake) and
|
||||
// caps the total number of entries to prevent unbounded memory growth.
|
||||
// A cleanup timer runs only when there are buffered entries and fires when the
|
||||
// oldest entry expires. Entries are kept in a linked list ordered by insertion
|
||||
// time so cleanup only needs to walk from the front.
|
||||
type earlyMsgBuffer struct {
|
||||
mu sync.Mutex
|
||||
index map[messages.PeerID]*list.Element
|
||||
order *list.List // front = oldest
|
||||
timer *time.Timer
|
||||
closed bool
|
||||
}
|
||||
|
||||
type earlyMsg struct {
|
||||
peerID messages.PeerID
|
||||
msg Msg
|
||||
createdAt time.Time
|
||||
}
|
||||
|
||||
func newEarlyMsgBuffer() *earlyMsgBuffer {
|
||||
return &earlyMsgBuffer{
|
||||
index: make(map[messages.PeerID]*list.Element),
|
||||
order: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// put stores or overwrites a message for the given peer. If a message for the
|
||||
// peer already exists, it is replaced with the new one. Returns false if the
|
||||
// message was not stored (buffer full or buffer closed).
|
||||
func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return false
|
||||
}
|
||||
|
||||
if existing, exists := b.index[peerID]; exists {
|
||||
old := b.order.Remove(existing).(earlyMsg)
|
||||
old.msg.Free()
|
||||
delete(b.index, peerID)
|
||||
}
|
||||
|
||||
if b.order.Len() >= earlyMsgCapacity {
|
||||
return false
|
||||
}
|
||||
|
||||
entry := earlyMsg{
|
||||
peerID: peerID,
|
||||
msg: msg,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
elem := b.order.PushBack(entry)
|
||||
b.index[peerID] = elem
|
||||
|
||||
// Start the cleanup timer if this is the first entry
|
||||
if b.order.Len() == 1 {
|
||||
b.scheduleCleanup(earlyMsgTTL)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// pop retrieves and removes the buffered message for the given peer.
|
||||
// Returns the message and true if found, zero value and false otherwise.
|
||||
func (b *earlyMsgBuffer) pop(peerID messages.PeerID) (Msg, bool) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
elem, ok := b.index[peerID]
|
||||
if !ok {
|
||||
return Msg{}, false
|
||||
}
|
||||
|
||||
entry := b.order.Remove(elem).(earlyMsg)
|
||||
delete(b.index, peerID)
|
||||
|
||||
if b.order.Len() == 0 {
|
||||
b.stopCleanup()
|
||||
}
|
||||
|
||||
return entry.msg, true
|
||||
}
|
||||
|
||||
// close stops the cleanup timer and frees all buffered messages.
|
||||
func (b *earlyMsgBuffer) close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
b.closed = true
|
||||
b.stopCleanup()
|
||||
|
||||
for elem := b.order.Front(); elem != nil; elem = elem.Next() {
|
||||
entry := elem.Value.(earlyMsg)
|
||||
entry.msg.Free()
|
||||
}
|
||||
b.order.Init()
|
||||
b.index = make(map[messages.PeerID]*list.Element)
|
||||
}
|
||||
|
||||
// scheduleCleanup starts or resets the timer. Caller must hold b.mu.
|
||||
func (b *earlyMsgBuffer) scheduleCleanup(d time.Duration) {
|
||||
if b.timer != nil {
|
||||
b.timer.Stop()
|
||||
}
|
||||
b.timer = time.AfterFunc(d, b.removeExpired)
|
||||
}
|
||||
|
||||
// stopCleanup stops the timer. Caller must hold b.mu.
|
||||
func (b *earlyMsgBuffer) stopCleanup() {
|
||||
if b.timer != nil {
|
||||
b.timer.Stop()
|
||||
b.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *earlyMsgBuffer) removeExpired() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for elem := b.order.Front(); elem != nil; {
|
||||
entry := elem.Value.(earlyMsg)
|
||||
if now.Sub(entry.createdAt) <= earlyMsgTTL {
|
||||
// Entries are ordered by time, so the rest are newer
|
||||
break
|
||||
}
|
||||
next := elem.Next()
|
||||
b.order.Remove(elem)
|
||||
delete(b.index, entry.peerID)
|
||||
entry.msg.Free()
|
||||
elem = next
|
||||
}
|
||||
|
||||
if b.order.Len() == 0 {
|
||||
b.timer = nil
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule next cleanup based on when the oldest entry expires
|
||||
front := b.order.Front()
|
||||
if front == nil {
|
||||
b.timer = nil
|
||||
return
|
||||
}
|
||||
oldest := front.Value.(earlyMsg).createdAt
|
||||
nextCleanup := earlyMsgTTL - now.Sub(oldest)
|
||||
b.scheduleCleanup(nextCleanup)
|
||||
}
|
||||
485
shared/relay/client/early_msg_buffer_test.go
Normal file
485
shared/relay/client/early_msg_buffer_test.go
Normal file
@@ -0,0 +1,485 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
|
||||
func newTestPool() *sync.Pool {
|
||||
return &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 64)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newTestMsg(pool *sync.Pool, payload string) Msg {
|
||||
bufPtr := pool.Get().(*[]byte)
|
||||
copy(*bufPtr, payload)
|
||||
return Msg{
|
||||
bufPool: pool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: (*bufPtr)[:len(payload)],
|
||||
}
|
||||
}
|
||||
|
||||
func peerID(id string) messages.PeerID {
|
||||
return messages.HashID(id)
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_PutAndPop(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
peer := peerID("peer1")
|
||||
msg := newTestMsg(pool, "hello")
|
||||
|
||||
if !buf.put(peer, msg) {
|
||||
t.Fatal("put should succeed")
|
||||
}
|
||||
|
||||
got, ok := buf.pop(peer)
|
||||
if !ok {
|
||||
t.Fatal("pop should find the message")
|
||||
}
|
||||
if string(got.Payload) != "hello" {
|
||||
t.Fatalf("expected payload 'hello', got '%s'", got.Payload)
|
||||
}
|
||||
got.Free()
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_PopNotFound(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
_, ok := buf.pop(peerID("nonexistent"))
|
||||
if ok {
|
||||
t.Fatal("pop should return false for unknown peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_PopAfterPopReturnsFalse(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
peer := peerID("peer1")
|
||||
|
||||
buf.put(peer, newTestMsg(pool, "data"))
|
||||
|
||||
got, ok := buf.pop(peer)
|
||||
if !ok {
|
||||
t.Fatal("first pop should succeed")
|
||||
}
|
||||
got.Free()
|
||||
|
||||
_, ok = buf.pop(peer)
|
||||
if ok {
|
||||
t.Fatal("second pop for the same peer should return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_OverwriteSamePeer(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
peer := peerID("peer1")
|
||||
|
||||
if !buf.put(peer, newTestMsg(pool, "first")) {
|
||||
t.Fatal("first put should succeed")
|
||||
}
|
||||
if !buf.put(peer, newTestMsg(pool, "second")) {
|
||||
t.Fatal("second put (overwrite) should succeed")
|
||||
}
|
||||
|
||||
got, ok := buf.pop(peer)
|
||||
if !ok {
|
||||
t.Fatal("pop should find the message")
|
||||
}
|
||||
if string(got.Payload) != "second" {
|
||||
t.Fatalf("expected payload 'second', got '%s'", got.Payload)
|
||||
}
|
||||
got.Free()
|
||||
|
||||
// No more messages should be present for this peer
|
||||
_, ok = buf.pop(peer)
|
||||
if ok {
|
||||
t.Fatal("pop should return false after the only message was already popped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_MultiplePeers(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
peers := []messages.PeerID{peerID("a"), peerID("b"), peerID("c")}
|
||||
|
||||
for i, p := range peers {
|
||||
msg := newTestMsg(pool, fmt.Sprintf("msg-%d", i))
|
||||
if !buf.put(p, msg) {
|
||||
t.Fatalf("put should succeed for peer %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Pop in reverse order to verify independence
|
||||
for i := len(peers) - 1; i >= 0; i-- {
|
||||
got, ok := buf.pop(peers[i])
|
||||
if !ok {
|
||||
t.Fatalf("pop should find message for peer %d", i)
|
||||
}
|
||||
expected := fmt.Sprintf("msg-%d", i)
|
||||
if string(got.Payload) != expected {
|
||||
t.Fatalf("expected payload '%s', got '%s'", expected, got.Payload)
|
||||
}
|
||||
got.Free()
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_Capacity(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
|
||||
// Fill to capacity
|
||||
for i := 0; i < earlyMsgCapacity; i++ {
|
||||
peer := peerID(fmt.Sprintf("peer-%d", i))
|
||||
msg := newTestMsg(pool, fmt.Sprintf("msg-%d", i))
|
||||
if !buf.put(peer, msg) {
|
||||
t.Fatalf("put should succeed for peer %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Next put for a new peer should fail
|
||||
msg := newTestMsg(pool, "overflow")
|
||||
if buf.put(peerID("overflow-peer"), msg) {
|
||||
t.Fatal("put should fail when buffer is at capacity")
|
||||
}
|
||||
msg.Free()
|
||||
|
||||
// Overwriting an existing peer should still work (it removes then adds)
|
||||
overwrite := newTestMsg(pool, "overwritten")
|
||||
if !buf.put(peerID("peer-0"), overwrite) {
|
||||
t.Fatal("overwrite should succeed even at capacity")
|
||||
}
|
||||
|
||||
got, ok := buf.pop(peerID("peer-0"))
|
||||
if !ok {
|
||||
t.Fatal("pop should find overwritten message")
|
||||
}
|
||||
if string(got.Payload) != "overwritten" {
|
||||
t.Fatalf("expected 'overwritten', got '%s'", got.Payload)
|
||||
}
|
||||
got.Free()
|
||||
|
||||
// Clean up remaining
|
||||
for i := 1; i < earlyMsgCapacity; i++ {
|
||||
peer := peerID(fmt.Sprintf("peer-%d", i))
|
||||
if m, ok := buf.pop(peer); ok {
|
||||
m.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_CapacityAfterPop(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
|
||||
// Fill to capacity
|
||||
for i := 0; i < earlyMsgCapacity; i++ {
|
||||
peer := peerID(fmt.Sprintf("peer-%d", i))
|
||||
if !buf.put(peer, newTestMsg(pool, "x")) {
|
||||
t.Fatalf("put should succeed for peer %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Pop one entry to free a slot
|
||||
got, ok := buf.pop(peerID("peer-0"))
|
||||
if !ok {
|
||||
t.Fatal("pop should succeed")
|
||||
}
|
||||
got.Free()
|
||||
|
||||
// Now a new peer should fit
|
||||
if !buf.put(peerID("new-peer"), newTestMsg(pool, "new")) {
|
||||
t.Fatal("put should succeed after popping one entry")
|
||||
}
|
||||
|
||||
// Clean up
|
||||
for i := 1; i < earlyMsgCapacity; i++ {
|
||||
if m, ok := buf.pop(peerID(fmt.Sprintf("peer-%d", i))); ok {
|
||||
m.Free()
|
||||
}
|
||||
}
|
||||
if m, ok := buf.pop(peerID("new-peer")); ok {
|
||||
m.Free()
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_PutAfterClose(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
|
||||
pool := newTestPool()
|
||||
buf.close()
|
||||
|
||||
msg := newTestMsg(pool, "too late")
|
||||
if buf.put(peerID("peer1"), msg) {
|
||||
t.Fatal("put should fail after close")
|
||||
}
|
||||
msg.Free()
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_PopAfterClose(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
|
||||
pool := newTestPool()
|
||||
buf.put(peerID("peer1"), newTestMsg(pool, "data"))
|
||||
buf.close()
|
||||
|
||||
// Messages are freed on close, so pop should not find anything
|
||||
_, ok := buf.pop(peerID("peer1"))
|
||||
if ok {
|
||||
t.Fatal("pop should return false after close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_DoubleClose(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
buf.close()
|
||||
buf.close() // should not panic
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_TTLExpiry(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
peer := peerID("peer1")
|
||||
|
||||
buf.put(peer, newTestMsg(pool, "expiring"))
|
||||
|
||||
// Wait for the TTL to expire plus some margin
|
||||
time.Sleep(earlyMsgTTL + 500*time.Millisecond)
|
||||
|
||||
_, ok := buf.pop(peer)
|
||||
if ok {
|
||||
t.Fatal("message should have been expired by cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_PartialExpiry(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
|
||||
// Insert first message
|
||||
buf.put(peerID("peer1"), newTestMsg(pool, "old"))
|
||||
|
||||
// Wait half the TTL, then insert second message
|
||||
time.Sleep(earlyMsgTTL / 2)
|
||||
|
||||
buf.put(peerID("peer2"), newTestMsg(pool, "new"))
|
||||
|
||||
// Wait for the first to expire but not the second
|
||||
time.Sleep(earlyMsgTTL/2 + 500*time.Millisecond)
|
||||
|
||||
// First should be gone
|
||||
_, ok := buf.pop(peerID("peer1"))
|
||||
if ok {
|
||||
t.Fatal("peer1 message should have expired")
|
||||
}
|
||||
|
||||
// Second should still be there
|
||||
got, ok := buf.pop(peerID("peer2"))
|
||||
if !ok {
|
||||
t.Fatal("peer2 message should still be present")
|
||||
}
|
||||
if string(got.Payload) != "new" {
|
||||
t.Fatalf("expected payload 'new', got '%s'", got.Payload)
|
||||
}
|
||||
got.Free()
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_BulkExpiry(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
peer := peerID(fmt.Sprintf("peer-%d", i))
|
||||
buf.put(peer, newTestMsg(pool, fmt.Sprintf("msg-%d", i)))
|
||||
}
|
||||
|
||||
// All should expire together
|
||||
time.Sleep(earlyMsgTTL + 500*time.Millisecond)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
_, ok := buf.pop(peerID(fmt.Sprintf("peer-%d", i)))
|
||||
if ok {
|
||||
t.Fatalf("peer-%d should have expired", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_ConcurrentPutAndPop(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
pool := newTestPool()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent puts
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
peer := peerID(fmt.Sprintf("peer-%d", id))
|
||||
msg := newTestMsg(pool, fmt.Sprintf("msg-%d", id))
|
||||
if !buf.put(peer, msg) {
|
||||
msg.Free()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Concurrent pops
|
||||
var popped int64
|
||||
var mu sync.Mutex
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
peer := peerID(fmt.Sprintf("peer-%d", id))
|
||||
if msg, ok := buf.pop(peer); ok {
|
||||
msg.Free()
|
||||
mu.Lock()
|
||||
popped++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if popped != 100 {
|
||||
t.Fatalf("expected to pop 100 messages, got %d", popped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_ConcurrentPutPopAndClose(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
|
||||
pool := newTestPool()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent puts
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
peer := peerID(fmt.Sprintf("peer-%d", id))
|
||||
msg := newTestMsg(pool, fmt.Sprintf("msg-%d", id))
|
||||
if !buf.put(peer, msg) {
|
||||
msg.Free()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent pops
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
peer := peerID(fmt.Sprintf("peer-%d", id))
|
||||
if msg, ok := buf.pop(peer); ok {
|
||||
msg.Free()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Close concurrently
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf.close()
|
||||
}()
|
||||
|
||||
wg.Wait() // should not panic or deadlock
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_OverwriteDoesNotLeak(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
// Use a dedicated pool to detect that overwritten message's Free was called
|
||||
freeCalled := make(chan struct{}, 1)
|
||||
origPool := &sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 64)
|
||||
return &b
|
||||
},
|
||||
}
|
||||
|
||||
b := make([]byte, 64)
|
||||
copy(b, "original")
|
||||
bufPtr := &b
|
||||
origMsg := Msg{
|
||||
bufPool: origPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: b[:8],
|
||||
}
|
||||
|
||||
peer := peerID("peer1")
|
||||
buf.put(peer, origMsg)
|
||||
|
||||
// Now check if the original buffer was freed by trying to get from pool
|
||||
// We need a wrapper pool that signals when Put is called
|
||||
trackPool := &sync.Pool{
|
||||
New: func() any {
|
||||
b := make([]byte, 64)
|
||||
return &b
|
||||
},
|
||||
}
|
||||
_ = trackPool
|
||||
|
||||
// Simpler approach: overwrite and check that only new value is returned
|
||||
newPool := newTestPool()
|
||||
buf.put(peer, newTestMsg(newPool, "replaced"))
|
||||
|
||||
// After overwrite, only the new message should be retrievable
|
||||
got, ok := buf.pop(peer)
|
||||
if !ok {
|
||||
t.Fatal("pop should find the message")
|
||||
}
|
||||
if string(got.Payload) != "replaced" {
|
||||
t.Fatalf("expected 'replaced', got '%s'", got.Payload)
|
||||
}
|
||||
got.Free()
|
||||
close(freeCalled)
|
||||
}
|
||||
|
||||
func TestEarlyMsgBuffer_EmptyBuffer(t *testing.T) {
|
||||
buf := newEarlyMsgBuffer()
|
||||
defer buf.close()
|
||||
|
||||
// Pop from empty buffer
|
||||
_, ok := buf.pop(peerID("anything"))
|
||||
if ok {
|
||||
t.Fatal("pop from empty buffer should return false")
|
||||
}
|
||||
|
||||
// Close empty buffer should be fine
|
||||
buf2 := newEarlyMsgBuffer()
|
||||
buf2.close()
|
||||
}
|
||||
Reference in New Issue
Block a user