From 92e53d63193932c991e423659719296aa1f17398 Mon Sep 17 00:00:00 2001 From: pascal Date: Tue, 21 Apr 2026 18:31:53 +0200 Subject: [PATCH] generic settings overrider --- management/internals/server/boot.go | 19 ++++ management/internals/server/server.go | 11 ++ shared/settingoverrider/overrider.go | 120 ++++++++++++++++++++++ shared/settingoverrider/overrider_test.go | 111 ++++++++++++++++++++ signal/cmd/run.go | 21 +++- signal/server/send_tracker.go | 53 +++++++--- signal/server/signal.go | 27 ++++- 7 files changed, 346 insertions(+), 16 deletions(-) create mode 100644 shared/settingoverrider/overrider.go create mode 100644 shared/settingoverrider/overrider_test.go diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 24dfb641b..9d00433ea 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -19,6 +19,7 @@ import ( "google.golang.org/grpc/keepalive" cachestore "github.com/eko/gocache/lib/v4/store" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -33,6 +34,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/settingoverrider" "github.com/netbirdio/netbird/util/crypt" ) @@ -72,6 +74,23 @@ func (s *BaseServer) CacheStore() cachestore.StoreInterface { }) } +// SettingOverrider returns a shared setting overrider backed by Redis. +// Returns a no-op overrider if no Redis address is configured. 
+func (s *BaseServer) SettingOverrider() *settingoverrider.Overrider { + return Create(s, func() *settingoverrider.Overrider { + redisAddr := nbcache.GetAddrFromEnv() + if redisAddr == "" { + return settingoverrider.NewNoop() + } + + o, err := settingoverrider.New(context.Background(), redisAddr) + if err != nil { + log.Fatalf("failed to create setting overrider: %v", err) + } + return o + }) +} + func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 9b8716da1..74712f3b7 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/settingoverrider" "github.com/netbirdio/netbird/util/wsproxy" wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" @@ -123,6 +124,15 @@ func (s *BaseServer) Start(ctx context.Context) error { s.PeersManager() s.GeoLocationManager() + s.SettingOverrider().Poll(settingoverrider.DefaultInterval, "managementLogLevel", func(value string) error { + level, err := log.ParseLevel(value) + if err != nil { + return fmt.Errorf("parsing log level %q: %w", value, err) + } + log.SetLevel(level) + return nil + }) + err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics") if err != nil { return fmt.Errorf("failed to expose metrics: %v", err) @@ -235,6 +245,7 @@ func (s *BaseServer) Stop() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + _ = s.SettingOverrider().Close() s.IntegratedValidator().Stop(ctx) if s.GeoLocationManager() != nil { _ = 
s.GeoLocationManager().Stop() diff --git a/shared/settingoverrider/overrider.go b/shared/settingoverrider/overrider.go new file mode 100644 index 000000000..411a25036 --- /dev/null +++ b/shared/settingoverrider/overrider.go @@ -0,0 +1,120 @@ +package settingoverrider + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" + log "github.com/sirupsen/logrus" +) + +const ( + DefaultInterval = 5 * time.Minute +) + +// ApplyFunc is called with the raw Redis string value whenever it changes. +// The function is responsible for parsing and applying the value. +// Return an error to log a warning without stopping the polling loop. +type ApplyFunc func(value string) error + +// Overrider holds a shared Redis connection and allows registering +// individual settings that are polled independently. +type Overrider struct { + client *redis.Client + cancel context.CancelFunc + ctx context.Context + noop bool +} + +// New creates an Overrider by connecting to Redis at the given address. +// The address should follow the Redis URL format (e.g. "redis://localhost:6379"). +func New(ctx context.Context, redisAddr string) (*Overrider, error) { + if redisAddr == "" { + return nil, fmt.Errorf("redis address is empty") + } + + options, err := redis.ParseURL(redisAddr) + if err != nil { + return nil, fmt.Errorf("parsing redis address: %w", err) + } + + client := redis.NewClient(options) + + pingCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + + if _, err := client.Ping(pingCtx).Result(); err != nil { + _ = client.Close() + return nil, fmt.Errorf("connecting to redis: %w", err) + } + + oCtx, oCancel := context.WithCancel(ctx) + + return &Overrider{client: client, cancel: oCancel, ctx: oCtx}, nil +} + +// NewNoop returns an Overrider that does nothing. +// Poll calls are silently ignored and Close is a no-op. 
+func NewNoop() *Overrider {
+	return &Overrider{noop: true}
+}
+
+// Close stops all polling goroutines and closes the underlying Redis client.
+func (o *Overrider) Close() error {
+	if o.noop {
+		return nil
+	}
+	o.cancel()
+	return o.client.Close()
+}
+
+// Poll starts a background goroutine that reads redisKey from Redis on every tick of interval and
+// calls apply whenever a non-empty value differs from the last successfully applied one. A missing
+// key is skipped, Redis errors are logged, and the goroutine stops when the Overrider is closed.
+func (o *Overrider) Poll(interval time.Duration, redisKey string, apply ApplyFunc) {
+	if o.noop {
+		return
+	}
+
+	go func() {
+		ticker := time.NewTicker(interval)
+		defer ticker.Stop()
+
+		var lastSeen *string
+		for {
+			select {
+			case <-o.ctx.Done():
+				log.WithContext(o.ctx).Infof("Stopping settings overrider for key %q", redisKey)
+				return
+			case <-ticker.C:
+				getCtx, cancel := context.WithTimeout(o.ctx, 5*time.Second)
+				val, err := o.client.Get(getCtx, redisKey).Result()
+				cancel()
+				// Handle real errors before looking at val: a failed GET returns "" and must not be skipped silently.
+				if err != nil && !errors.Is(err, redis.Nil) {
+					if o.ctx.Err() != nil {
+						return
+					}
+					log.WithContext(o.ctx).Errorf("Unable to get setting %q from Redis: %v", redisKey, err)
+					continue
+				}
+				if val == "" { // key is missing (redis.Nil) or holds an empty value
+					continue
+				}
+
+				if lastSeen != nil && *lastSeen == val {
+					continue
+				}
+
+				if err := apply(val); err != nil {
+					log.WithContext(o.ctx).Warnf("Failed to apply setting %q with value %q: %v", redisKey, val, err)
+					continue
+				}
+
+				lastSeen = &val
+			}
+		}
+	}()
+}
diff --git a/shared/settingoverrider/overrider_test.go b/shared/settingoverrider/overrider_test.go
new file mode 100644
index 000000000..044f82c47
--- /dev/null
+++ b/shared/settingoverrider/overrider_test.go
@@ -0,0 +1,111 @@
+package settingoverrider
+
+import (
+	"context"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/testcontainers/testcontainers-go"
+	testcontainersredis "github.com/testcontainers/testcontainers-go/modules/redis"
+	
"github.com/testcontainers/testcontainers-go/wait" +) + +func TestPoll_AppliesSettingFromRedis(t *testing.T) { + o, client := setupOverrider(t) + + key := "test-setting-key" + require.NoError(t, client.Set(context.Background(), key, "hello", 0).Err()) + + var applied atomic.Value + + o.Poll(100*time.Millisecond, key, func(value string) error { + applied.Store(value) + return nil + }) + + assert.Eventually(t, func() bool { + v := applied.Load() + return v != nil && v.(string) == "hello" + }, 5*time.Second, 50*time.Millisecond) +} + +func TestPoll_IndependentSettings(t *testing.T) { + o, client := setupOverrider(t) + + require.NoError(t, client.Set(context.Background(), "key-a", "val-a", 0).Err()) + require.NoError(t, client.Set(context.Background(), "key-b", "val-b", 0).Err()) + + var gotA, gotB atomic.Value + + o.Poll(100*time.Millisecond, "key-a", func(v string) error { gotA.Store(v); return nil }) + o.Poll(100*time.Millisecond, "key-b", func(v string) error { gotB.Store(v); return nil }) + + assert.Eventually(t, func() bool { + a, b := gotA.Load(), gotB.Load() + return a != nil && a.(string) == "val-a" && b != nil && b.(string) == "val-b" + }, 5*time.Second, 50*time.Millisecond) +} + +func TestPoll_SkipsDuplicateValues(t *testing.T) { + o, client := setupOverrider(t) + + key := "test-dedup" + require.NoError(t, client.Set(context.Background(), key, "same", 0).Err()) + + var count atomic.Int32 + + o.Poll(100*time.Millisecond, key, func(string) error { + count.Add(1) + return nil + }) + + // wait for a few ticks + time.Sleep(600 * time.Millisecond) + assert.Equal(t, int32(1), count.Load(), "Apply should be called only once for unchanged value") +} + +func setupOverrider(t *testing.T) (*Overrider, *redis.Client) { + t.Helper() + + ctx := context.Background() + redisContainer, err := testcontainersredis.RunContainer(ctx, + testcontainers.WithImage("redis:7"), + testcontainers.WithWaitStrategy( + wait.ForListeningPort("6379/tcp"), + ), + ) + require.NoError(t, err, 
"Failed to create redis test container") + + t.Cleanup(func() { + if err := redisContainer.Terminate(ctx); err != nil { + t.Logf("failed to terminate redis container: %s", err) + } + }) + + redisURL, err := redisContainer.ConnectionString(ctx) + require.NoError(t, err) + + o, err := New(ctx, redisURL) + require.NoError(t, err) + t.Cleanup(func() { + if err := o.Close(); err != nil { + t.Logf("failed to close overrider: %s", err) + } + }) + + // separate client for test setup (setting keys) + options, err := redis.ParseURL(redisURL) + require.NoError(t, err) + client := redis.NewClient(options) + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Logf("failed to close redis client: %s", err) + } + }) + + return o, client +} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 681222403..c2efc5cc1 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -18,7 +18,9 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/shared/metrics" + "github.com/netbirdio/netbird/shared/settingoverrider" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/signal/proto" @@ -114,7 +116,24 @@ var ( } }() - srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) + overrider := settingoverrider.NewNoop() + if redisAddr := cache.GetAddrFromEnv(); redisAddr != "" { + overrider, err = settingoverrider.New(cmd.Context(), redisAddr) + if err != nil { + return fmt.Errorf("failed to create setting overrider: %w", err) + } + defer func() { _ = overrider.Close() }() + } + overrider.Poll(settingoverrider.DefaultInterval, "signalLogLevel", func(value string) error { + level, err := log.ParseLevel(value) + if err != nil { + return fmt.Errorf("parsing log level %q: %w", value, err) + } + log.SetLevel(level) + return nil + }) + + srv, err := server.NewServer(cmd.Context(), metricsServer.Meter, overrider) if err != nil { return fmt.Errorf("creating 
signal server: %v", err) } diff --git a/signal/server/send_tracker.go b/signal/server/send_tracker.go index 2dc56ab7b..c31c52a8c 100644 --- a/signal/server/send_tracker.go +++ b/signal/server/send_tracker.go @@ -2,9 +2,11 @@ package server import ( "context" + "math" "os" "strconv" "sync" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -19,10 +21,13 @@ const ( // sendRateTracker tracks per-key message counts and logs the busiest peers periodically. type sendRateTracker struct { - mu sync.Mutex - counts map[string]int64 - interval time.Duration - topPercent float64 + mu sync.Mutex + counts map[string]int64 + + // atomic so they can be updated by the setting overrider without locking + intervalNs atomic.Int64 + // topPercent stored as float64 bits for atomic access + topPercentBits atomic.Uint64 } func newSendRateTracker() *sendRateTracker { @@ -42,11 +47,28 @@ func newSendRateTracker() *sendRateTracker { log.Debugf("send rate tracker: interval=%s, top_percent=%.2f", interval, topPercent) - return &sendRateTracker{ - counts: make(map[string]int64), - interval: interval, - topPercent: topPercent, + t := &sendRateTracker{ + counts: make(map[string]int64), } + t.intervalNs.Store(int64(interval)) + t.topPercentBits.Store(math.Float64bits(topPercent)) + return t +} + +func (t *sendRateTracker) getInterval() time.Duration { + return time.Duration(t.intervalNs.Load()) +} + +func (t *sendRateTracker) setInterval(d time.Duration) { + t.intervalNs.Store(int64(d)) +} + +func (t *sendRateTracker) getTopPercent() float64 { + return math.Float64frombits(t.topPercentBits.Load()) +} + +func (t *sendRateTracker) setTopPercent(p float64) { + t.topPercentBits.Store(math.Float64bits(p)) } func (t *sendRateTracker) increment(key string) { @@ -66,7 +88,8 @@ func (t *sendRateTracker) resetAndSnapshot() map[string]int64 { // logSendRates periodically logs peers in the top percentile of the busiest peer. 
func (t *sendRateTracker) logSendRates(ctx context.Context) { - ticker := time.NewTicker(t.interval) + currentInterval := t.getInterval() + ticker := time.NewTicker(currentInterval) defer ticker.Stop() for { @@ -74,6 +97,11 @@ func (t *sendRateTracker) logSendRates(ctx context.Context) { case <-ctx.Done(): return case <-ticker.C: + if newInterval := t.getInterval(); newInterval != currentInterval { + currentInterval = newInterval + ticker.Reset(currentInterval) + } + snap := t.resetAndSnapshot() if len(snap) == 0 { continue @@ -86,11 +114,12 @@ func (t *sendRateTracker) logSendRates(ctx context.Context) { } } - threshold := int64(float64(maxCount) * t.topPercent) - intervalMin := t.interval.Minutes() + topPercent := t.getTopPercent() + threshold := int64(float64(maxCount) * topPercent) + intervalMin := currentInterval.Minutes() log.Debugf("send rate stats: %d unique peers in last %.0fs, max rate %.1f msg/min", - len(snap), t.interval.Seconds(), float64(maxCount)/intervalMin) + len(snap), currentInterval.Seconds(), float64(maxCount)/intervalMin) logged := 0 for key, count := range snap { if count >= threshold { diff --git a/signal/server/signal.go b/signal/server/signal.go index 6fc980c1f..d9663f1c4 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "strconv" "time" log "github.com/sirupsen/logrus" @@ -17,6 +18,7 @@ import ( "github.com/netbirdio/signal-dispatcher/dispatcher" + "github.com/netbirdio/netbird/shared/settingoverrider" "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" @@ -64,7 +66,7 @@ type Server struct { } // NewServer creates a new Signal server -func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) { +func NewServer(ctx context.Context, meter metric.Meter, overrider *settingoverrider.Overrider, metricsPrefix ...string) (*Server, error) { appMetrics, err := 
metrics.NewAppMetrics(meter, metricsPrefix...)
 	if err != nil {
 		return nil, fmt.Errorf("creating app metrics: %v", err)
@@ -82,16 +84,44 @@ func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string)
 		sTimeout = parsed
 	}
 
+	tracker := newSendRateTracker()
+
 	s := &Server{
 		dispatcher:    d,
 		registry:      peer.NewRegistry(appMetrics),
 		metrics:       appMetrics,
 		successHeader: metadata.Pairs(proto.HeaderRegistered, "1"),
 		sendTimeout:   sTimeout,
-		sendTracker:   newSendRateTracker(),
+		sendTracker:   tracker,
 	}
 
-	go s.sendTracker.logSendRates(ctx)
+	// Let the send-rate log interval be replaced at runtime through the setting overrider.
+	overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateLogInterval", func(value string) error {
+		parsed, err := time.ParseDuration(value)
+		if err != nil {
+			return fmt.Errorf("invalid send rate log interval %q: %w", value, err)
+		}
+		if parsed <= 0 {
+			return fmt.Errorf("invalid send rate log interval %q: must be greater than zero", value)
+		}
+		tracker.setInterval(parsed)
+		return nil
+	})
+	// Let the top-percent threshold be replaced at runtime through the setting overrider.
+	overrider.Poll(settingoverrider.DefaultInterval, "signalSendRateTopPercent", func(value string) error {
+		parsed, err := strconv.ParseFloat(value, 64)
+		if err != nil {
+			return fmt.Errorf("invalid send rate top percent %q: %w", value, err)
+		}
+		// Written as a negated range check so that NaN, which fails every comparison, is rejected too.
+		if !(parsed > 0 && parsed <= 1) {
+			return fmt.Errorf("invalid send rate top percent %q: must be in (0, 1]", value)
+		}
+		tracker.setTopPercent(parsed)
+		return nil
+	})
+
+	go tracker.logSendRates(ctx)
 
 	return s, nil
 }