// Tests for the early message buffer that captures transport messages
// arriving before OpenConn completes, ensuring correct message ordering and
// no dropped messages.
package client

import (
	"fmt"
	"sync"
	"testing"
	"time"

	"github.com/netbirdio/netbird/shared/relay/messages"
)

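// The tests below exercise the earlyMsgBuffer API as it is used in this
// file: newEarlyMsgBuffer() constructs the buffer, put(peer, msg) stores at
// most one pending Msg per peer (bounded by earlyMsgCapacity and expired
// after earlyMsgTTL), pop(peer) removes and returns that Msg, and close()
// frees anything still buffered. Payloads are pool-backed, so every Msg
// returned by pop, and every Msg rejected by put, must be released with Free.
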
// newTestPool returns a fixed-size byte-buffer pool for backing test message
// payloads.
func newTestPool() *sync.Pool {
	return &sync.Pool{
		New: func() any {
			buf := make([]byte, 64)
			return &buf
		},
	}
}

// newTestMsg builds a pool-backed Msg whose Payload holds the given string.
func newTestMsg(pool *sync.Pool, payload string) Msg {
	bufPtr := pool.Get().(*[]byte)
	copy(*bufPtr, payload)
	return Msg{
		bufPool: pool,
		bufPtr:  bufPtr,
		Payload: (*bufPtr)[:len(payload)],
	}
}

// peerID hashes a human-readable ID into a messages.PeerID.
func peerID(id string) messages.PeerID {
	return messages.HashID(id)
}

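// A minimal sketch of the intended call sites, inferred from the buffer's
// purpose (capturing transport messages that arrive before OpenConn
// completes); conns, earlyBuf, and deliver are hypothetical names, not part
// of this file:
//
//	// transport read loop: no connection registered for this peer yet
//	if _, ok := conns[peerID]; !ok {
//		if !earlyBuf.put(peerID, msg) {
//			msg.Free() // buffer full or closed; release the pooled payload
//		}
//	}
//
//	// OpenConn: after registering the connection, drain any early message
//	// first so arrival order is preserved and nothing is dropped
//	if msg, ok := earlyBuf.pop(peerID); ok {
//		deliver(msg) // caller must msg.Free() when done
//	}
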
func TestEarlyMsgBuffer_PutAndPop(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()
	peer := peerID("peer1")
	msg := newTestMsg(pool, "hello")

	if !buf.put(peer, msg) {
		t.Fatal("put should succeed")
	}

	got, ok := buf.pop(peer)
	if !ok {
		t.Fatal("pop should find the message")
	}
	if string(got.Payload) != "hello" {
		t.Fatalf("expected payload 'hello', got '%s'", got.Payload)
	}
	got.Free()
}

func TestEarlyMsgBuffer_PopNotFound(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	_, ok := buf.pop(peerID("nonexistent"))
	if ok {
		t.Fatal("pop should return false for unknown peer")
	}
}

func TestEarlyMsgBuffer_PopAfterPopReturnsFalse(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()
	peer := peerID("peer1")

	buf.put(peer, newTestMsg(pool, "data"))

	got, ok := buf.pop(peer)
	if !ok {
		t.Fatal("first pop should succeed")
	}
	got.Free()

	_, ok = buf.pop(peer)
	if ok {
		t.Fatal("second pop for the same peer should return false")
	}
}

func TestEarlyMsgBuffer_OverwriteSamePeer(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()
	peer := peerID("peer1")

	if !buf.put(peer, newTestMsg(pool, "first")) {
		t.Fatal("first put should succeed")
	}
	if !buf.put(peer, newTestMsg(pool, "second")) {
		t.Fatal("second put (overwrite) should succeed")
	}

	got, ok := buf.pop(peer)
	if !ok {
		t.Fatal("pop should find the message")
	}
	if string(got.Payload) != "second" {
		t.Fatalf("expected payload 'second', got '%s'", got.Payload)
	}
	got.Free()

	// No more messages should be present for this peer
	_, ok = buf.pop(peer)
	if ok {
		t.Fatal("pop should return false after the only message was already popped")
	}
}

func TestEarlyMsgBuffer_MultiplePeers(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()
	peers := []messages.PeerID{peerID("a"), peerID("b"), peerID("c")}

	for i, p := range peers {
		msg := newTestMsg(pool, fmt.Sprintf("msg-%d", i))
		if !buf.put(p, msg) {
			t.Fatalf("put should succeed for peer %d", i)
		}
	}

	// Pop in reverse order to verify independence
	for i := len(peers) - 1; i >= 0; i-- {
		got, ok := buf.pop(peers[i])
		if !ok {
			t.Fatalf("pop should find message for peer %d", i)
		}
		expected := fmt.Sprintf("msg-%d", i)
		if string(got.Payload) != expected {
			t.Fatalf("expected payload '%s', got '%s'", expected, got.Payload)
		}
		got.Free()
	}
}

func TestEarlyMsgBuffer_Capacity(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()

	// Fill to capacity
	for i := 0; i < earlyMsgCapacity; i++ {
		peer := peerID(fmt.Sprintf("peer-%d", i))
		msg := newTestMsg(pool, fmt.Sprintf("msg-%d", i))
		if !buf.put(peer, msg) {
			t.Fatalf("put should succeed for peer %d", i)
		}
	}

	// Next put for a new peer should fail
	msg := newTestMsg(pool, "overflow")
	if buf.put(peerID("overflow-peer"), msg) {
		t.Fatal("put should fail when buffer is at capacity")
	}
	msg.Free()

	// Overwriting an existing peer should still work (it removes then adds)
	overwrite := newTestMsg(pool, "overwritten")
	if !buf.put(peerID("peer-0"), overwrite) {
		t.Fatal("overwrite should succeed even at capacity")
	}

	got, ok := buf.pop(peerID("peer-0"))
	if !ok {
		t.Fatal("pop should find overwritten message")
	}
	if string(got.Payload) != "overwritten" {
		t.Fatalf("expected 'overwritten', got '%s'", got.Payload)
	}
	got.Free()

	// Clean up remaining
	for i := 1; i < earlyMsgCapacity; i++ {
		peer := peerID(fmt.Sprintf("peer-%d", i))
		if m, ok := buf.pop(peer); ok {
			m.Free()
		}
	}
}

func TestEarlyMsgBuffer_CapacityAfterPop(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()

	// Fill to capacity
	for i := 0; i < earlyMsgCapacity; i++ {
		peer := peerID(fmt.Sprintf("peer-%d", i))
		if !buf.put(peer, newTestMsg(pool, "x")) {
			t.Fatalf("put should succeed for peer %d", i)
		}
	}

	// Pop one entry to free a slot
	got, ok := buf.pop(peerID("peer-0"))
	if !ok {
		t.Fatal("pop should succeed")
	}
	got.Free()

	// Now a new peer should fit
	if !buf.put(peerID("new-peer"), newTestMsg(pool, "new")) {
		t.Fatal("put should succeed after popping one entry")
	}

	// Clean up
	for i := 1; i < earlyMsgCapacity; i++ {
		if m, ok := buf.pop(peerID(fmt.Sprintf("peer-%d", i))); ok {
			m.Free()
		}
	}
	if m, ok := buf.pop(peerID("new-peer")); ok {
		m.Free()
	}
}

func TestEarlyMsgBuffer_PutAfterClose(t *testing.T) {
	buf := newEarlyMsgBuffer()

	pool := newTestPool()
	buf.close()

	msg := newTestMsg(pool, "too late")
	if buf.put(peerID("peer1"), msg) {
		t.Fatal("put should fail after close")
	}
	msg.Free()
}

func TestEarlyMsgBuffer_PopAfterClose(t *testing.T) {
	buf := newEarlyMsgBuffer()

	pool := newTestPool()
	buf.put(peerID("peer1"), newTestMsg(pool, "data"))
	buf.close()

	// Messages are freed on close, so pop should not find anything
	_, ok := buf.pop(peerID("peer1"))
	if ok {
		t.Fatal("pop should return false after close")
	}
}

func TestEarlyMsgBuffer_DoubleClose(t *testing.T) {
	buf := newEarlyMsgBuffer()
	buf.close()
	buf.close() // should not panic
}

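// The TTL tests below are timing-based: they assume the background cleanup
// fires often enough that earlyMsgTTL plus a 500ms margin suffices to
// observe expiry, and TestEarlyMsgBuffer_PartialExpiry additionally assumes
// earlyMsgTTL comfortably exceeds one second (peer2's message is aged
// earlyMsgTTL/2 + 500ms when checked).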
func TestEarlyMsgBuffer_TTLExpiry(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()
	peer := peerID("peer1")

	buf.put(peer, newTestMsg(pool, "expiring"))

	// Wait for the TTL to expire plus some margin
	time.Sleep(earlyMsgTTL + 500*time.Millisecond)

	_, ok := buf.pop(peer)
	if ok {
		t.Fatal("message should have been expired by cleanup")
	}
}

func TestEarlyMsgBuffer_PartialExpiry(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()

	// Insert first message
	buf.put(peerID("peer1"), newTestMsg(pool, "old"))

	// Wait half the TTL, then insert second message
	time.Sleep(earlyMsgTTL / 2)
	buf.put(peerID("peer2"), newTestMsg(pool, "new"))

	// Wait for the first to expire but not the second
	time.Sleep(earlyMsgTTL/2 + 500*time.Millisecond)

	// First should be gone
	_, ok := buf.pop(peerID("peer1"))
	if ok {
		t.Fatal("peer1 message should have expired")
	}

	// Second should still be there
	got, ok := buf.pop(peerID("peer2"))
	if !ok {
		t.Fatal("peer2 message should still be present")
	}
	if string(got.Payload) != "new" {
		t.Fatalf("expected payload 'new', got '%s'", got.Payload)
	}
	got.Free()
}

func TestEarlyMsgBuffer_BulkExpiry(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()

	for i := 0; i < 50; i++ {
		peer := peerID(fmt.Sprintf("peer-%d", i))
		buf.put(peer, newTestMsg(pool, fmt.Sprintf("msg-%d", i)))
	}

	// All should expire together
	time.Sleep(earlyMsgTTL + 500*time.Millisecond)

	for i := 0; i < 50; i++ {
		_, ok := buf.pop(peerID(fmt.Sprintf("peer-%d", i)))
		if ok {
			t.Fatalf("peer-%d should have expired", i)
		}
	}
}

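// Note: the test below expects all 100 concurrent puts to be retrievable,
// which assumes earlyMsgCapacity is at least 100; a rejected put frees its
// message in the goroutine and would leave the final count short.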
func TestEarlyMsgBuffer_ConcurrentPutAndPop(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	pool := newTestPool()
	var wg sync.WaitGroup

	// Concurrent puts
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			peer := peerID(fmt.Sprintf("peer-%d", id))
			msg := newTestMsg(pool, fmt.Sprintf("msg-%d", id))
			if !buf.put(peer, msg) {
				msg.Free()
			}
		}(i)
	}
	wg.Wait()

	// Concurrent pops
	var popped int64
	var mu sync.Mutex
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			peer := peerID(fmt.Sprintf("peer-%d", id))
			if msg, ok := buf.pop(peer); ok {
				msg.Free()
				mu.Lock()
				popped++
				mu.Unlock()
			}
		}(i)
	}
	wg.Wait()

	if popped != 100 {
		t.Fatalf("expected to pop 100 messages, got %d", popped)
	}
}

func TestEarlyMsgBuffer_ConcurrentPutPopAndClose(t *testing.T) {
	buf := newEarlyMsgBuffer()

	pool := newTestPool()
	var wg sync.WaitGroup

	// Concurrent puts
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			peer := peerID(fmt.Sprintf("peer-%d", id))
			msg := newTestMsg(pool, fmt.Sprintf("msg-%d", id))
			if !buf.put(peer, msg) {
				msg.Free()
			}
		}(i)
	}

	// Concurrent pops
	for i := 0; i < 50; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			peer := peerID(fmt.Sprintf("peer-%d", id))
			if msg, ok := buf.pop(peer); ok {
				msg.Free()
			}
		}(i)
	}

	// Close concurrently
	wg.Add(1)
	go func() {
		defer wg.Done()
		buf.close()
	}()

	wg.Wait() // should not panic or deadlock
}

func TestEarlyMsgBuffer_OverwriteDoesNotLeak(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	// Build the first message by hand, backed by its own dedicated pool, so
	// the overwrite below is forced to release it through Msg.Free.
	origPool := &sync.Pool{
		New: func() any {
			b := make([]byte, 64)
			return &b
		},
	}
	b := make([]byte, 64)
	copy(b, "original")
	origMsg := Msg{
		bufPool: origPool,
		bufPtr:  &b,
		Payload: b[:8],
	}

	peer := peerID("peer1")
	buf.put(peer, origMsg)

	// Overwrite with a message from a second pool; after the overwrite, only
	// the new value should be retrievable.
	newPool := newTestPool()
	buf.put(peer, newTestMsg(newPool, "replaced"))

	got, ok := buf.pop(peer)
	if !ok {
		t.Fatal("pop should find the message")
	}
	if string(got.Payload) != "replaced" {
		t.Fatalf("expected 'replaced', got '%s'", got.Payload)
	}
	got.Free()
}

func TestEarlyMsgBuffer_EmptyBuffer(t *testing.T) {
	buf := newEarlyMsgBuffer()
	defer buf.close()

	// Pop from empty buffer
	_, ok := buf.pop(peerID("anything"))
	if ok {
		t.Fatal("pop from empty buffer should return false")
	}

	// Closing an empty buffer should be fine
	buf2 := newEarlyMsgBuffer()
	buf2.close()
}