mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
Add early message buffer to capture transport messages arriving before OpenConn completes, ensuring correct message ordering and no dropped messages.
176 lines
3.9 KiB
Go
176 lines
3.9 KiB
Go
package client
|
|
|
|
import (
|
|
"container/list"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/netbirdio/netbird/shared/relay/messages"
|
|
)
|
|
|
|
// Limits applied to the early message buffer.
const (
	// earlyMsgCapacity is the maximum number of entries held at the same time.
	earlyMsgCapacity = 1000

	// earlyMsgTTL is how long an entry may stay buffered before the cleanup
	// pass frees it.
	earlyMsgTTL = 5 * time.Second
)
|
|
|
|
// earlyMsgBuffer buffers transport messages that arrive before the corresponding
|
|
// OpenConn call. This happens during reconnection when the remote peer sends data
|
|
// before the local side has set up the relay connection.
|
|
//
|
|
// It stores at most one message per peer (the first WireGuard handshake) and
|
|
// caps the total number of entries to prevent unbounded memory growth.
|
|
// A cleanup timer runs only when there are buffered entries and fires when the
|
|
// oldest entry expires. Entries are kept in a linked list ordered by insertion
|
|
// time so cleanup only needs to walk from the front.
|
|
type earlyMsgBuffer struct {
|
|
mu sync.Mutex
|
|
index map[messages.PeerID]*list.Element
|
|
order *list.List // front = oldest
|
|
timer *time.Timer
|
|
closed bool
|
|
}
|
|
|
|
type earlyMsg struct {
|
|
peerID messages.PeerID
|
|
msg Msg
|
|
createdAt time.Time
|
|
}
|
|
|
|
func newEarlyMsgBuffer() *earlyMsgBuffer {
|
|
return &earlyMsgBuffer{
|
|
index: make(map[messages.PeerID]*list.Element),
|
|
order: list.New(),
|
|
}
|
|
}
|
|
|
|
// put stores or overwrites a message for the given peer. If a message for the
|
|
// peer already exists, it is replaced with the new one. Returns false if the
|
|
// message was not stored (buffer full or buffer closed).
|
|
func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
if b.closed {
|
|
return false
|
|
}
|
|
|
|
if existing, exists := b.index[peerID]; exists {
|
|
old := b.order.Remove(existing).(earlyMsg)
|
|
old.msg.Free()
|
|
delete(b.index, peerID)
|
|
}
|
|
|
|
if b.order.Len() >= earlyMsgCapacity {
|
|
return false
|
|
}
|
|
|
|
entry := earlyMsg{
|
|
peerID: peerID,
|
|
msg: msg,
|
|
createdAt: time.Now(),
|
|
}
|
|
elem := b.order.PushBack(entry)
|
|
b.index[peerID] = elem
|
|
|
|
// Start the cleanup timer if this is the first entry
|
|
if b.order.Len() == 1 {
|
|
b.scheduleCleanup(earlyMsgTTL)
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// pop retrieves and removes the buffered message for the given peer.
|
|
// Returns the message and true if found, zero value and false otherwise.
|
|
func (b *earlyMsgBuffer) pop(peerID messages.PeerID) (Msg, bool) {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
elem, ok := b.index[peerID]
|
|
if !ok {
|
|
return Msg{}, false
|
|
}
|
|
|
|
entry := b.order.Remove(elem).(earlyMsg)
|
|
delete(b.index, peerID)
|
|
|
|
if b.order.Len() == 0 {
|
|
b.stopCleanup()
|
|
}
|
|
|
|
return entry.msg, true
|
|
}
|
|
|
|
// close stops the cleanup timer and frees all buffered messages.
|
|
func (b *earlyMsgBuffer) close() {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
if b.closed {
|
|
return
|
|
}
|
|
b.closed = true
|
|
b.stopCleanup()
|
|
|
|
for elem := b.order.Front(); elem != nil; elem = elem.Next() {
|
|
entry := elem.Value.(earlyMsg)
|
|
entry.msg.Free()
|
|
}
|
|
b.order.Init()
|
|
b.index = make(map[messages.PeerID]*list.Element)
|
|
}
|
|
|
|
// scheduleCleanup starts or resets the timer. Caller must hold b.mu.
|
|
func (b *earlyMsgBuffer) scheduleCleanup(d time.Duration) {
|
|
if b.timer != nil {
|
|
b.timer.Stop()
|
|
}
|
|
b.timer = time.AfterFunc(d, b.removeExpired)
|
|
}
|
|
|
|
// stopCleanup stops the timer. Caller must hold b.mu.
|
|
func (b *earlyMsgBuffer) stopCleanup() {
|
|
if b.timer != nil {
|
|
b.timer.Stop()
|
|
b.timer = nil
|
|
}
|
|
}
|
|
|
|
func (b *earlyMsgBuffer) removeExpired() {
|
|
b.mu.Lock()
|
|
defer b.mu.Unlock()
|
|
|
|
if b.closed {
|
|
return
|
|
}
|
|
|
|
now := time.Now()
|
|
for elem := b.order.Front(); elem != nil; {
|
|
entry := elem.Value.(earlyMsg)
|
|
if now.Sub(entry.createdAt) <= earlyMsgTTL {
|
|
// Entries are ordered by time, so the rest are newer
|
|
break
|
|
}
|
|
next := elem.Next()
|
|
b.order.Remove(elem)
|
|
delete(b.index, entry.peerID)
|
|
entry.msg.Free()
|
|
elem = next
|
|
}
|
|
|
|
if b.order.Len() == 0 {
|
|
b.timer = nil
|
|
return
|
|
}
|
|
|
|
// Schedule next cleanup based on when the oldest entry expires
|
|
front := b.order.Front()
|
|
if front == nil {
|
|
b.timer = nil
|
|
return
|
|
}
|
|
oldest := front.Value.(earlyMsg).createdAt
|
|
nextCleanup := earlyMsgTTL - now.Sub(oldest)
|
|
b.scheduleCleanup(nextCleanup)
|
|
}
|