Files
netbird/shared/relay/client/early_msg_buffer.go

176 lines
3.9 KiB
Go

package client
import (
"container/list"
"sync"
"time"
"github.com/netbirdio/netbird/shared/relay/messages"
)
// Limits applied to the early message buffer.
const (
// earlyMsgTTL is the maximum age of a buffered message; the cleanup timer
// is armed with this duration and older entries are dropped and freed.
earlyMsgTTL = 5 * time.Second
// earlyMsgCapacity is the maximum number of buffered entries; put refuses
// to store a message once the buffer holds this many.
earlyMsgCapacity = 1000
)
// earlyMsgBuffer buffers transport messages that arrive before the corresponding
// OpenConn call. This happens during reconnection when the remote peer sends data
// before the local side has set up the relay connection.
//
// It stores at most one message per peer: put replaces (and frees) any message
// already buffered for that peer, so the most recently received one is kept.
// NOTE(review): the original comment described this as "the first WireGuard
// handshake"; the payload type is not checked here — confirm against callers.
// The total number of entries is capped at earlyMsgCapacity to prevent
// unbounded memory growth.
// A cleanup timer runs only when there are buffered entries and fires when the
// oldest entry expires. Entries are kept in a linked list ordered by insertion
// time so cleanup only needs to walk from the front.
//
// Every method either takes mu itself or documents that the caller must hold it.
type earlyMsgBuffer struct {
mu sync.Mutex
// index maps a peer to its element in order; element values are earlyMsg.
index map[messages.PeerID]*list.Element
order *list.List // front = oldest
// timer is the pending cleanup timer; nil when no cleanup is scheduled.
timer *time.Timer
// closed is set by close; afterwards put rejects every message.
closed bool
}
// earlyMsg is the value stored in each element of earlyMsgBuffer.order.
type earlyMsg struct {
// peerID is the peer the message was buffered for; it is the key used to
// remove the matching index entry when the message expires.
peerID messages.PeerID
// msg is the buffered message; it is released with Free when the entry is
// replaced, expired, or discarded on close.
msg Msg
// createdAt is the insertion time, compared against earlyMsgTTL in cleanup.
createdAt time.Time
}
// newEarlyMsgBuffer returns an empty, open buffer. No cleanup timer is
// created until the first message is stored.
func newEarlyMsgBuffer() *earlyMsgBuffer {
	buf := new(earlyMsgBuffer)
	buf.index = map[messages.PeerID]*list.Element{}
	buf.order = list.New()
	return buf
}
// put buffers msg for peerID. Any message already held for that peer is
// removed and freed first, and the new entry is appended to the back of the
// list with a fresh timestamp.
//
// It reports whether the message was stored. A false result means the buffer
// is closed or already holds earlyMsgCapacity entries; in that case msg is
// not freed here.
func (b *earlyMsgBuffer) put(peerID messages.PeerID, msg Msg) bool {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return false
	}

	// Evict the previous message for this peer before checking capacity, so
	// a replacement never fails because of the entry it replaces.
	if prev, found := b.index[peerID]; found {
		stale := b.order.Remove(prev).(earlyMsg)
		stale.msg.Free()
		delete(b.index, peerID)
	}

	if b.order.Len() >= earlyMsgCapacity {
		return false
	}

	b.index[peerID] = b.order.PushBack(earlyMsg{
		peerID:    peerID,
		msg:       msg,
		createdAt: time.Now(),
	})

	// The buffer just went from empty to one entry: arm the cleanup timer.
	if b.order.Len() == 1 {
		b.scheduleCleanup(earlyMsgTTL)
	}
	return true
}
// pop removes and returns the message buffered for peerID. The boolean is
// false, and the returned Msg is the zero value, when nothing is buffered for
// that peer. Ownership of the returned message passes to the caller; it is
// not freed here.
func (b *earlyMsgBuffer) pop(peerID messages.PeerID) (Msg, bool) {
	b.mu.Lock()
	defer b.mu.Unlock()

	elem, found := b.index[peerID]
	if !found {
		var none Msg
		return none, false
	}

	buffered := elem.Value.(earlyMsg)
	b.order.Remove(elem)
	delete(b.index, peerID)

	// Nothing left to expire, so the timer has no work to do.
	if b.order.Len() == 0 {
		b.stopCleanup()
	}
	return buffered.msg, true
}
// close marks the buffer as closed, cancels the cleanup timer, and frees
// every message still buffered. Calling it again is a no-op.
func (b *earlyMsgBuffer) close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return
	}
	b.closed = true
	b.stopCleanup()

	// Release each buffered message before dropping the containers.
	elem := b.order.Front()
	for elem != nil {
		buffered := elem.Value.(earlyMsg)
		buffered.msg.Free()
		elem = elem.Next()
	}

	b.order.Init()
	b.index = map[messages.PeerID]*list.Element{}
}
// scheduleCleanup arms the cleanup timer to call removeExpired after d,
// stopping any previously scheduled timer first. The caller must hold b.mu.
func (b *earlyMsgBuffer) scheduleCleanup(d time.Duration) {
	if prev := b.timer; prev != nil {
		prev.Stop()
	}
	b.timer = time.AfterFunc(d, b.removeExpired)
}
// stopCleanup cancels the pending cleanup timer, if there is one, and clears
// the timer field. The caller must hold b.mu.
func (b *earlyMsgBuffer) stopCleanup() {
	if b.timer == nil {
		return
	}
	b.timer.Stop()
	b.timer = nil
}
// removeExpired is the cleanup timer callback. It drops and frees every
// entry whose age has reached earlyMsgTTL, then re-arms the timer for the
// moment the oldest surviving entry expires, or clears the timer field when
// nothing is left. It does nothing once the buffer is closed.
func (b *earlyMsgBuffer) removeExpired() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return
	}

	now := time.Now()
	for elem := b.order.Front(); elem != nil; {
		entry := elem.Value.(earlyMsg)
		// An entry whose age equals the TTL counts as expired. Keeping it
		// would re-arm the timer with a zero delay and re-fire immediately.
		if now.Sub(entry.createdAt) < earlyMsgTTL {
			// Entries are ordered by insertion time, so the rest are newer.
			break
		}
		// Capture the successor first: Remove detaches elem from the list.
		next := elem.Next()
		b.order.Remove(elem)
		delete(b.index, entry.peerID)
		entry.msg.Free()
		elem = next
	}

	front := b.order.Front()
	if front == nil {
		// Buffer is empty; this timer has fired and no new one is needed.
		b.timer = nil
		return
	}

	// The surviving front entry is younger than the TTL, so the delay below
	// is strictly positive.
	oldest := front.Value.(earlyMsg).createdAt
	b.scheduleCleanup(earlyMsgTTL - now.Sub(oldest))
}