diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 3d782f04c..7e07bf322 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -19,6 +19,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -963,10 +964,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]*UpdateBuffer) + metrics, err := telemetry.NewUpdateChannelMetrics(context.Background(), noop.NewMeterProvider().Meter("test")) + if err != nil { + b.Fatalf("Failed to create update channel metrics: %v", err) + } for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = NewUpdateBuffer(metrics) } manager.peersUpdateManager.peerChannels = peerChannels @@ -1028,17 +1033,24 @@ func TestUpdateAccountPeers(t *testing.T) { t.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]*UpdateBuffer) + metrics, err := telemetry.NewUpdateChannelMetrics(context.Background(), noop.NewMeterProvider().Meter("test")) + if err != nil { + t.Fatalf("Failed to create update channel metrics: %v", err) + } for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = NewUpdateBuffer(metrics) } manager.peersUpdateManager.peerChannels = peerChannels manager.UpdateAccountPeers(ctx, account.Id) for _, channel := range peerChannels { - update := <-channel + update, ok := channel.Pop(context.Background()) + if !ok { + t.Fatalf("Expected update for peer, but channel is empty") + } assert.Nil(t, update.Update.NetbirdConfig) assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index b2184717d..db6f6c23c 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -79,7 +79,19 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { secret := "some_secret" peersManager := NewPeersUpdateManager(nil) peer := "some_peer" - updateChannel := peersManager.CreateChannel(context.Background(), peer) + buffer := peersManager.CreateChannel(context.Background(), peer) + resultCh := make(chan struct { + msg *UpdateMessage + ok bool + }, 1) + + go func() { + msg, ok := buffer.Pop(context.Background()) + resultCh <- struct { + msg *UpdateMessage + ok bool + }{msg, ok} + }() rc := &types.Relay{ Addresses: []string{"localhost:0"}, @@ -117,8 +129,8 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { loop: for timeout := time.After(5 * time.Second); ; { select { - case update := <-updateChannel: - updates = append(updates, update) + case update := <-resultCh: + updates = append(updates, update.msg) case <-timeout: break loop } diff --git a/management/server/update_buffer.go b/management/server/update_buffer.go index 9f4de6ed0..087341678 100644 --- a/management/server/update_buffer.go +++ b/management/server/update_buffer.go @@ -25,17 +25,15 @@ func (b *UpdateBuffer) Push(update *UpdateMessage) { b.mu.Lock() defer b.mu.Unlock() - if b.update == nil { - b.update = update - b.cond.Signal() - b.metrics.CountBufferPush() - return - } - // the equal case we need because we don't always increment the serial number - if update.NetworkMap.Network.Serial >= b.update.NetworkMap.Network.Serial { + if b.update == nil || update.Update.NetworkMap.Serial > b.update.Update.NetworkMap.Serial || b.update.Update.NetworkMap.Serial == 0 { b.update = update b.cond.Signal() + if b.update == nil { + b.metrics.CountBufferPush() + return + } + b.metrics.CountBufferOverwrite() return } @@ -50,19 +48,15 @@ func (b *UpdateBuffer) Pop(ctx context.Context) (*UpdateMessage, bool) { for b.update == nil && !b.closed { waitCh := make(chan struct{}) go func() { - b.cond.Wait() - close(waitCh) + select { + case <-ctx.Done(): + b.cond.Broadcast() + case <-waitCh: + // noop + } }() - - b.mu.Unlock() - select { - case <-ctx.Done(): - b.mu.Lock() - return nil, false - case <-waitCh: - // Wakeup due to Push() or Close() - } - b.mu.Lock() + b.cond.Wait() + close(waitCh) } if b.closed { diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 0d0eb64fd..6bd1c6422 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -80,6 +80,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) * } // mbragin: todo shouldn't it be more? or configurable? buffer := NewUpdateBuffer(p.metrics.UpdateChannelMetrics()) + p.peerChannels[peerID] = buffer log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c..649e73e50 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/telemetry" ) // var peersUpdater *PeersUpdateManager @@ -23,7 +24,12 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" - peersUpdater := NewPeersUpdateManager(nil) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("failed to create metrics: %v", err) + } + peersUpdater := NewPeersUpdateManager(metrics) update1 := &UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 0, @@ -33,10 +39,26 @@ func TestSendUpdate(t *testing.T) { if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } + + resultCh := make(chan struct { + msg *UpdateMessage + ok bool + }, 1) + + go func() { + for range [channelBufferSize]int{} { + msg, ok := peersUpdater.peerChannels[peer].Pop(context.Background()) + resultCh <- struct { + msg *UpdateMessage + ok bool + }{msg, ok} + } + }() + peersUpdater.SendUpdate(context.Background(), peer, update1) select { - case <-peersUpdater.peerChannels[peer]: - default: + case <-resultCh: + case <-time.After(1 * time.Second): t.Error("Update wasn't send") } @@ -56,8 +78,8 @@ func TestSendUpdate(t *testing.T) { select { case <-timeout: t.Error("timed out reading previously sent updates") - case updateReader := <-peersUpdater.peerChannels[peer]: - if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial { + case updateReader := <-resultCh: + if updateReader.msg.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial { t.Error("got the update that shouldn't have been sent") } }