Compare commits

...

4 Commits

Author SHA1 Message Date
pascal
37b9905b68 Merge branch 'main' into feature/log-most-busy-peers 2026-04-23 15:41:26 +02:00
pascal
92e53d6319 generic settings overrider 2026-04-21 18:31:53 +02:00
pascal
8a7d78ddf3 make configurable via env 2026-04-21 15:42:44 +02:00
pascal
ea83cbf917 log the most busy peers 2026-04-21 15:27:05 +02:00
8 changed files with 501 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
cachestore "github.com/eko/gocache/lib/v4/store" cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
@@ -34,6 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/util/crypt" "github.com/netbirdio/netbird/util/crypt"
) )
@@ -73,6 +75,23 @@ func (s *BaseServer) CacheStore() cachestore.StoreInterface {
}) })
} }
// SettingOverrider returns a shared setting overrider backed by Redis.
// Returns a no-op overrider if no Redis address is configured.
func (s *BaseServer) SettingOverrider() *settingoverrider.Overrider {
	return Create(s, func() *settingoverrider.Overrider {
		// Redis address comes from the environment, same source the cache layer uses.
		redisAddr := nbcache.GetAddrFromEnv()
		if redisAddr == "" {
			// No Redis configured: Poll calls on the returned overrider are no-ops.
			return settingoverrider.NewNoop()
		}
		o, err := settingoverrider.New(context.Background(), redisAddr)
		if err != nil {
			// NOTE(review): Fatalf aborts the whole process if Redis is unreachable
			// at first use of the overrider — confirm this hard failure is intended.
			log.Fatalf("failed to create setting overrider: %v", err)
		}
		return o
	})
}
func (s *BaseServer) Store() store.Store { func (s *BaseServer) Store() store.Store {
return Create(s, func() store.Store { return Create(s, func() store.Store {
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/util/wsproxy" "github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -123,6 +124,15 @@ func (s *BaseServer) Start(ctx context.Context) error {
s.PeersManager() s.PeersManager()
s.GeoLocationManager() s.GeoLocationManager()
s.SettingOverrider().Poll(settingoverrider.DefaultInterval, "managementLogLevel", func(value string) error {
level, err := log.ParseLevel(value)
if err != nil {
return fmt.Errorf("parsing log level %q: %w", value, err)
}
log.SetLevel(level)
return nil
})
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics") err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
if err != nil { if err != nil {
return fmt.Errorf("failed to expose metrics: %v", err) return fmt.Errorf("failed to expose metrics: %v", err)
@@ -235,6 +245,7 @@ func (s *BaseServer) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_ = s.SettingOverrider().Close()
s.IntegratedValidator().Stop(ctx) s.IntegratedValidator().Stop(ctx)
if s.GeoLocationManager() != nil { if s.GeoLocationManager() != nil {
_ = s.GeoLocationManager().Stop() _ = s.GeoLocationManager().Stop()

View File

@@ -0,0 +1,120 @@
package settingoverrider
import (
"context"
"errors"
"fmt"
"time"
"github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus"
)
const (
	// DefaultInterval is the polling cadence used by callers that do not
	// need a custom interval.
	DefaultInterval = 5 * time.Minute
)

// ApplyFunc is called with the raw Redis string value whenever it changes.
// The function is responsible for parsing and applying the value.
// Return an error to log a warning without stopping the polling loop.
type ApplyFunc func(value string) error

// Overrider holds a shared Redis connection and allows registering
// individual settings that are polled independently.
type Overrider struct {
	client *redis.Client      // shared connection; nil when noop is true
	cancel context.CancelFunc // stops every polling goroutine on Close
	ctx    context.Context    // parent context for all polling goroutines
	noop   bool               // when true, Poll and Close do nothing
}
// New creates an Overrider by connecting to Redis at the given address.
// The address should follow the Redis URL format (e.g. "redis://localhost:6379").
func New(ctx context.Context, redisAddr string) (*Overrider, error) {
	if redisAddr == "" {
		return nil, fmt.Errorf("redis address is empty")
	}

	opts, err := redis.ParseURL(redisAddr)
	if err != nil {
		return nil, fmt.Errorf("parsing redis address: %w", err)
	}
	c := redis.NewClient(opts)

	// Verify connectivity up front so a misconfigured address fails fast.
	pingCtx, cancelPing := context.WithTimeout(ctx, 2*time.Second)
	defer cancelPing()
	if _, err = c.Ping(pingCtx).Result(); err != nil {
		_ = c.Close()
		return nil, fmt.Errorf("connecting to redis: %w", err)
	}

	pollCtx, stop := context.WithCancel(ctx)
	return &Overrider{
		client: c,
		cancel: stop,
		ctx:    pollCtx,
	}, nil
}
// NewNoop returns an Overrider that does nothing.
// Poll calls are silently ignored and Close is a no-op.
func NewNoop() *Overrider {
	var o Overrider
	o.noop = true
	return &o
}
// Close stops all polling goroutines and closes the underlying Redis client.
func (o *Overrider) Close() error {
	if !o.noop {
		o.cancel()
		return o.client.Close()
	}
	return nil
}
// Poll starts a background goroutine that polls a single Redis key at the given interval
// and calls apply whenever the value changes. The goroutine stops when the Overrider is
// closed. On a no-op Overrider this does nothing.
func (o *Overrider) Poll(interval time.Duration, redisKey string, apply ApplyFunc) {
	if o.noop {
		return
	}
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		// lastSeen stays nil until a value has been applied successfully,
		// so the first observed value is always passed to apply.
		var lastSeen *string
		for {
			select {
			case <-o.ctx.Done():
				log.WithContext(o.ctx).Infof("Stopping settings overrider for key %q", redisKey)
				return
			case <-ticker.C:
				getCtx, cancel := context.WithTimeout(o.ctx, 5*time.Second)
				val, err := o.client.Get(getCtx, redisKey).Result()
				cancel()
				// Check the error before the value: Get returns "" on any
				// failure, so testing val == "" first would silently swallow
				// real errors (the original ordering never logged them).
				if err != nil {
					if errors.Is(err, redis.Nil) {
						// Key not set: nothing to override.
						continue
					}
					if o.ctx.Err() != nil {
						// Overrider was closed while the GET was in flight.
						return
					}
					log.WithContext(o.ctx).Errorf("Unable to get setting %q from Redis: %v", redisKey, err)
					continue
				}
				if val == "" {
					continue
				}
				if lastSeen != nil && *lastSeen == val {
					continue
				}
				if err := apply(val); err != nil {
					log.WithContext(o.ctx).Warnf("Failed to apply setting %q with value %q: %v", redisKey, val, err)
					continue
				}
				lastSeen = &val
			}
		}
	}()
}

View File

@@ -0,0 +1,111 @@
package settingoverrider
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
"github.com/testcontainers/testcontainers-go/wait"
)
// TestPoll_AppliesSettingFromRedis verifies that a value stored in Redis is
// picked up by the polling loop and handed to the apply callback.
func TestPoll_AppliesSettingFromRedis(t *testing.T) {
	o, client := setupOverrider(t)

	const key = "test-setting-key"
	require.NoError(t, client.Set(context.Background(), key, "hello", 0).Err())

	var applied atomic.Value
	o.Poll(100*time.Millisecond, key, func(value string) error {
		applied.Store(value)
		return nil
	})

	assert.Eventually(t, func() bool {
		got, ok := applied.Load().(string)
		return ok && got == "hello"
	}, 5*time.Second, 50*time.Millisecond)
}
// TestPoll_IndependentSettings verifies that two keys registered on the same
// Overrider are polled and applied independently of each other.
func TestPoll_IndependentSettings(t *testing.T) {
	o, client := setupOverrider(t)
	ctx := context.Background()

	require.NoError(t, client.Set(ctx, "key-a", "val-a", 0).Err())
	require.NoError(t, client.Set(ctx, "key-b", "val-b", 0).Err())

	var gotA, gotB atomic.Value
	o.Poll(100*time.Millisecond, "key-a", func(v string) error { gotA.Store(v); return nil })
	o.Poll(100*time.Millisecond, "key-b", func(v string) error { gotB.Store(v); return nil })

	assert.Eventually(t, func() bool {
		a, aOK := gotA.Load().(string)
		b, bOK := gotB.Load().(string)
		return aOK && bOK && a == "val-a" && b == "val-b"
	}, 5*time.Second, 50*time.Millisecond)
}
// TestPoll_SkipsDuplicateValues verifies that an unchanged value triggers the
// apply callback exactly once even across many polling ticks.
func TestPoll_SkipsDuplicateValues(t *testing.T) {
	o, client := setupOverrider(t)

	const key = "test-dedup"
	require.NoError(t, client.Set(context.Background(), key, "same", 0).Err())

	var count atomic.Int32
	o.Poll(100*time.Millisecond, key, func(string) error {
		count.Add(1)
		return nil
	})

	// Let several ticks elapse; dedup must keep the apply count at one.
	time.Sleep(600 * time.Millisecond)
	assert.Equal(t, int32(1), count.Load(), "Apply should be called only once for unchanged value")
}
// setupOverrider starts a throwaway Redis container and returns an Overrider
// connected to it plus a raw Redis client tests can use to seed keys.
// All resources are released via t.Cleanup when the test finishes.
func setupOverrider(t *testing.T) (*Overrider, *redis.Client) {
	t.Helper()
	ctx := context.Background()

	// Real Redis instance in a container; waits until the port accepts connections.
	redisContainer, err := testcontainersredis.RunContainer(ctx,
		testcontainers.WithImage("redis:7"),
		testcontainers.WithWaitStrategy(
			wait.ForListeningPort("6379/tcp"),
		),
	)
	require.NoError(t, err, "Failed to create redis test container")
	t.Cleanup(func() {
		if err := redisContainer.Terminate(ctx); err != nil {
			t.Logf("failed to terminate redis container: %s", err)
		}
	})

	// ConnectionString yields a redis:// URL, the format New expects.
	redisURL, err := redisContainer.ConnectionString(ctx)
	require.NoError(t, err)

	o, err := New(ctx, redisURL)
	require.NoError(t, err)
	t.Cleanup(func() {
		if err := o.Close(); err != nil {
			t.Logf("failed to close overrider: %s", err)
		}
	})

	// separate client for test setup (setting keys)
	options, err := redis.ParseURL(redisURL)
	require.NoError(t, err)
	client := redis.NewClient(options)
	t.Cleanup(func() {
		if err := client.Close(); err != nil {
			t.Logf("failed to close redis client: %s", err)
		}
	})
	return o, client
}

View File

@@ -18,7 +18,9 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/shared/metrics" "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
@@ -114,7 +116,24 @@ var (
} }
}() }()
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) overrider := settingoverrider.NewNoop()
if redisAddr := cache.GetAddrFromEnv(); redisAddr != "" {
overrider, err = settingoverrider.New(cmd.Context(), redisAddr)
if err != nil {
return fmt.Errorf("failed to create setting overrider: %w", err)
}
defer func() { _ = overrider.Close() }()
}
overrider.Poll(settingoverrider.DefaultInterval, "signalLogLevel", func(value string) error {
level, err := log.ParseLevel(value)
if err != nil {
return fmt.Errorf("parsing log level %q: %w", value, err)
}
log.SetLevel(level)
return nil
})
srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, overrider)
if err != nil { if err != nil {
return fmt.Errorf("creating signal server: %v", err) return fmt.Errorf("creating signal server: %v", err)
} }

View File

@@ -0,0 +1,135 @@
package server
import (
"context"
"math"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
	// Defaults; both can be overridden via the environment variables below
	// at startup, and at runtime via the setting overrider.
	defaultSendRateLogInterval = 5 * time.Minute
	defaultSendRateTopPercent  = 0.95

	envSendRateLogInterval = "NB_SIGNAL_SEND_RATE_LOG_INTERVAL"
	envSendRateTopPercent  = "NB_SIGNAL_SEND_RATE_LOG_TOP_PERCENT"
)

// sendRateTracker tracks per-key message counts and logs the busiest peers periodically.
type sendRateTracker struct {
	mu     sync.Mutex
	counts map[string]int64 // per-peer message counts since the last snapshot; guarded by mu
	// atomic so they can be updated by the setting overrider without locking
	intervalNs atomic.Int64
	// topPercent stored as float64 bits for atomic access
	topPercentBits atomic.Uint64
}
// newSendRateTracker builds a tracker configured from the environment,
// falling back to the package defaults for missing or invalid values.
func newSendRateTracker() *sendRateTracker {
	interval := defaultSendRateLogInterval
	if raw := os.Getenv(envSendRateLogInterval); raw != "" {
		if d, err := time.ParseDuration(raw); err == nil && d > 0 {
			interval = d
		}
	}

	topPercent := defaultSendRateTopPercent
	if raw := os.Getenv(envSendRateTopPercent); raw != "" {
		if f, err := strconv.ParseFloat(raw, 64); err == nil && f > 0 && f <= 1 {
			topPercent = f
		}
	}

	log.Debugf("send rate tracker: interval=%s, top_percent=%.2f", interval, topPercent)

	tracker := &sendRateTracker{counts: map[string]int64{}}
	tracker.intervalNs.Store(int64(interval))
	tracker.topPercentBits.Store(math.Float64bits(topPercent))
	return tracker
}
// getInterval returns the current logging interval (lock-free atomic read).
func (t *sendRateTracker) getInterval() time.Duration {
	return time.Duration(t.intervalNs.Load())
}

// setInterval updates the logging interval; picked up by the log loop on its next tick.
func (t *sendRateTracker) setInterval(d time.Duration) {
	t.intervalNs.Store(int64(d))
}

// getTopPercent returns the current top-percent threshold (lock-free atomic read).
func (t *sendRateTracker) getTopPercent() float64 {
	return math.Float64frombits(t.topPercentBits.Load())
}

// setTopPercent updates the top-percent threshold; picked up on the next log cycle.
func (t *sendRateTracker) setTopPercent(p float64) {
	t.topPercentBits.Store(math.Float64bits(p))
}
// increment records one sent message for the given peer key.
func (t *sendRateTracker) increment(key string) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.counts[key]++
}
// resetAndSnapshot atomically returns current counts and resets the tracker.
func (t *sendRateTracker) resetAndSnapshot() map[string]int64 {
	t.mu.Lock()
	defer t.mu.Unlock()
	out := t.counts
	t.counts = make(map[string]int64, len(out))
	return out
}
// logSendRates periodically logs peers in the top percentile of the busiest peer.
// It runs until ctx is cancelled and picks up runtime changes to the interval
// and top-percent settings between ticks.
func (t *sendRateTracker) logSendRates(ctx context.Context) {
	currentInterval := t.getInterval()
	ticker := time.NewTicker(currentInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// The snapshot taken below was accumulated over the interval that
			// just elapsed. Remember it before applying any runtime interval
			// change, so msg/min rates use the correct divisor — the original
			// divided by the *new* interval, misstating rates for one cycle.
			elapsed := currentInterval
			if newInterval := t.getInterval(); newInterval != currentInterval {
				currentInterval = newInterval
				ticker.Reset(currentInterval)
			}
			snap := t.resetAndSnapshot()
			if len(snap) == 0 {
				continue
			}
			var maxCount int64
			for _, count := range snap {
				if count > maxCount {
					maxCount = count
				}
			}
			topPercent := t.getTopPercent()
			// Peers at or above this share of the busiest peer's count are logged.
			threshold := int64(float64(maxCount) * topPercent)
			intervalMin := elapsed.Minutes()
			log.Debugf("send rate stats: %d unique peers in last %.0fs, max rate %.1f msg/min",
				len(snap), elapsed.Seconds(), float64(maxCount)/intervalMin)
			logged := 0
			for key, count := range snap {
				if count >= threshold {
					log.Debugf("peer [%s] %.1f msg/min", key, float64(count)/intervalMin)
					logged++
					// Cap output so a large deployment cannot flood the log.
					if logged >= 100 {
						break
					}
				}
			}
		}
	}
}

View File

@@ -0,0 +1,56 @@
package server
import (
"sync"
"testing"
)
// TestSendRateTracker_Increment checks that per-key counts accumulate correctly.
func TestSendRateTracker_Increment(t *testing.T) {
	tracker := newSendRateTracker()

	for _, key := range []string{"peer-a", "peer-a", "peer-b"} {
		tracker.increment(key)
	}

	snap := tracker.resetAndSnapshot()
	if got := snap["peer-a"]; got != 2 {
		t.Errorf("expected peer-a count 2, got %d", got)
	}
	if got := snap["peer-b"]; got != 1 {
		t.Errorf("expected peer-b count 1, got %d", got)
	}
}
// TestSendRateTracker_ResetAndSnapshot_Resets checks that taking a snapshot
// clears the accumulated counts.
func TestSendRateTracker_ResetAndSnapshot_Resets(t *testing.T) {
	tracker := newSendRateTracker()
	tracker.increment("peer-a")

	first := tracker.resetAndSnapshot()
	if first["peer-a"] != 1 {
		t.Fatalf("expected 1, got %d", first["peer-a"])
	}

	second := tracker.resetAndSnapshot()
	if len(second) != 0 {
		t.Errorf("expected empty snapshot after reset, got %v", second)
	}
}
// TestSendRateTracker_ConcurrentIncrement checks that increment is safe under
// concurrent use and loses no updates (most valuable when run with -race).
func TestSendRateTracker_ConcurrentIncrement(t *testing.T) {
	tracker := newSendRateTracker()

	var wg sync.WaitGroup
	wg.Add(100)
	for i := 0; i < 100; i++ {
		go func() {
			defer wg.Done()
			tracker.increment("peer-x")
		}()
	}
	wg.Wait()

	if got := tracker.resetAndSnapshot()["peer-x"]; got != 100 {
		t.Errorf("expected 100, got %d", got)
	}
}

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"strconv"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -17,6 +18,7 @@ import (
"github.com/netbirdio/signal-dispatcher/dispatcher" "github.com/netbirdio/signal-dispatcher/dispatcher"
"github.com/netbirdio/netbird/shared/settingoverrider"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/peer"
@@ -59,10 +61,12 @@ type Server struct {
successHeader metadata.MD successHeader metadata.MD
sendTimeout time.Duration sendTimeout time.Duration
sendTracker *sendRateTracker
} }
// NewServer creates a new Signal server // NewServer creates a new Signal server
func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) { func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverrider.Overrider, metricsPrefix ...string) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...) appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err) return nil, fmt.Errorf("creating app metrics: %v", err)
@@ -80,14 +84,36 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
sTimeout = parsed sTimeout = parsed
} }
tracker := newSendRateTracker()
s := &Server{ s := &Server{
dispatcher: d, dispatcher: d,
registry: peer.NewRegistry(appMetrics), registry: peer.NewRegistry(appMetrics),
metrics: appMetrics, metrics: appMetrics,
successHeader: metadata.Pairs(proto.HeaderRegistered, "1"), successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
sendTimeout: sTimeout, sendTimeout: sTimeout,
sendTracker: tracker,
} }
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateLogInterval", func(value string) error {
parsed, err := time.ParseDuration(value)
if err != nil || parsed <= 0 {
return fmt.Errorf("invalid send rate log interval %q: %w", value, err)
}
tracker.setInterval(parsed)
return nil
})
overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateTopPercent", func(value string) error {
parsed, err := strconv.ParseFloat(value, 64)
if err != nil || parsed <= 0 || parsed > 1 {
return fmt.Errorf("invalid send rate top percent %q: %w", value, err)
}
tracker.setTopPercent(parsed)
return nil
})
go tracker.logSendRates(ctx)
return s, nil return s, nil
} }
@@ -95,6 +121,8 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) log.Tracef("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
s.sendTracker.increment(msg.Key)
if _, found := s.registry.Get(msg.RemoteKey); found { if _, found := s.registry.Get(msg.RemoteKey); found {
s.forwardMessageToPeer(ctx, msg) s.forwardMessageToPeer(ctx, msg)
return &proto.EncryptedMessage{}, nil return &proto.EncryptedMessage{}, nil