diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 12b59b691..393192eeb 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -86,7 +86,7 @@ func NewServer( if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(len(peersUpdateManager.peerChannels)) + return int64(peersUpdateManager.GetChannelCount()) }) if err != nil { return nil, err diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index adf64592a..bcef2c6ee 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -11,17 +11,21 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) -const channelBufferSize = 100 - type UpdateMessage struct { Update *proto.SyncResponse } +type peerUpdate struct { + mu sync.Mutex + message *UpdateMessage + notify chan struct{} +} + type PeersUpdateManager struct { - // peerChannels is an update channel indexed by Peer.ID - peerChannels map[string]chan *UpdateMessage - // channelsMux keeps the mutex to access peerChannels - channelsMux *sync.RWMutex + // latestUpdates stores the latest update message per peer + latestUpdates sync.Map // map[string]*peerUpdate + // activePeers tracks which peers have active sender goroutines + activePeers sync.Map // map[string]struct{} // metrics provides method to collect application metrics metrics telemetry.AppMetrics } @@ -29,87 +33,137 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - channelsMux: &sync.RWMutex{}, - metrics: metrics, + metrics: metrics, } } -// SendUpdate sends update message to the peer's channel +// SendUpdate stores the latest update message for a peer and notifies the sender goroutine func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { start := time.Now() var found, dropped bool - p.channelsMux.RLock() - defer func() { - p.channelsMux.RUnlock() if p.metrics != nil { p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped) } }() - if channel, ok := p.peerChannels[peerID]; ok { - found = true - select { - case channel <- update: - log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) - default: - dropped = true - log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel)) - } - } else { - log.WithContext(ctx).Debugf("peer %s has no channel", peerID) + // Check if peer has an active sender goroutine + if _, ok := p.activePeers.Load(peerID); !ok { + log.WithContext(ctx).Debugf("peer %s has no active sender", peerID) + return + } + + found = true + + // Load or create peerUpdate entry + val, _ := p.latestUpdates.LoadOrStore(peerID, &peerUpdate{ + notify: make(chan struct{}, 1), + }) + + pu := val.(*peerUpdate) + + // Store the latest message (overwrites any previous unsent message) + pu.mu.Lock() + pu.message = update + pu.mu.Unlock() + + // Non-blocking notification + select { + case pu.notify <- struct{}{}: + log.WithContext(ctx).Debugf("update notification sent for peer %s", peerID) + default: + // Already notified, sender will pick up the latest message anyway + log.WithContext(ctx).Tracef("peer %s already notified, update will be picked up", peerID) } } -// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. +// CreateChannel creates a sender goroutine for a given peer and returns a channel to receive updates func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { start := time.Now() closed := false - p.channelsMux.Lock() defer func() { - p.channelsMux.Unlock() if p.metrics != nil { p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed) } }() - if channel, ok := p.peerChannels[peerID]; ok { + // Close existing sender if any + if _, exists := p.activePeers.LoadOrStore(peerID, struct{}{}); exists { closed = true - delete(p.peerChannels, peerID) - close(channel) + p.closeChannel(ctx, peerID) } - // mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateMessage, channelBufferSize) - p.peerChannels[peerID] = channel - log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) + // Create peerUpdate entry with notification channel + pu := &peerUpdate{ + notify: make(chan struct{}, 1), + } + p.latestUpdates.Store(peerID, pu) - return channel + // Create output channel for consumer + outChan := make(chan *UpdateMessage, 1) + + // Start sender goroutine + go func() { + defer close(outChan) + for { + select { + case <-ctx.Done(): + log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped due to context cancellation", peerID) + return + case <-pu.notify: + // Check if still active + if _, ok := p.activePeers.Load(peerID); !ok { + log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped", peerID) + return + } + + // Get the latest message with mutex protection + pu.mu.Lock() + msg := pu.message + pu.message = nil // Clear after reading + pu.mu.Unlock() + + if msg != nil { + select { + case outChan <- msg: + log.WithContext(ctx).Tracef("sent update to peer %s", peerID) + case <-ctx.Done(): + return + } + } + } + } + }() + + log.WithContext(ctx).Debugf("created sender goroutine for peer %s", peerID) + + return outChan } func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { - if channel, ok := p.peerChannels[peerID]; ok { - delete(p.peerChannels, peerID) - close(channel) - - log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) + // Mark peer as inactive to stop the sender goroutine + if _, ok := p.activePeers.LoadAndDelete(peerID); ok { + // Close notification channel + if val, ok := p.latestUpdates.Load(peerID); ok { + pu := val.(*peerUpdate) + close(pu.notify) + } + p.latestUpdates.Delete(peerID) + log.WithContext(ctx).Debugf("closed sender for peer %s", peerID) return } - log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID) + log.WithContext(ctx).Debugf("closing sender: peer %s has no active sender", peerID) } -// CloseChannels closes updates channel for each given peer +// CloseChannels closes sender goroutines for each given peer func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) { start := time.Now() - p.channelsMux.Lock() defer func() { - p.channelsMux.Unlock() if p.metrics != nil { p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs)) } @@ -120,13 +174,11 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string } } -// CloseChannel closes updates channel of a given peer +// CloseChannel closes the sender goroutine of a given peer func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { start := time.Now() - p.channelsMux.Lock() defer func() { - p.channelsMux.Unlock() if p.metrics != nil { p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start)) } @@ -139,38 +191,43 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} { start := time.Now() - p.channelsMux.RLock() - m := make(map[string]struct{}) defer func() { - p.channelsMux.RUnlock() if p.metrics != nil { p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m)) } }() - for ID := range p.peerChannels { - m[ID] = struct{}{} - } + p.activePeers.Range(func(key, value interface{}) bool { + m[key.(string)] = struct{}{} + return true + }) return m } -// HasChannel returns true if peers has channel in update manager, otherwise false +// HasChannel returns true if peer has an active sender goroutine, otherwise false func (p *PeersUpdateManager) HasChannel(peerID string) bool { start := time.Now() - p.channelsMux.RLock() - defer func() { - p.channelsMux.RUnlock() if p.metrics != nil { p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start)) } }() - _, ok := p.peerChannels[peerID] + _, ok := p.activePeers.Load(peerID) return ok } + +// GetChannelCount returns the number of active peer channels +func (p *PeersUpdateManager) GetChannelCount() int { + count := 0 + p.activePeers.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +}