mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-02 04:39:55 +00:00
Compare commits
3 Commits
main
...
fix/slow-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e55b252765 | ||
|
|
b7a2fe6a60 | ||
|
|
f077aa3599 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
.claude
|
||||
.idea
|
||||
.run
|
||||
*.iml
|
||||
|
||||
@@ -17,12 +17,15 @@ import (
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
statsCache *statsCache
|
||||
}
|
||||
|
||||
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
|
||||
return &KernelConfigurer{
|
||||
c := &KernelConfigurer{
|
||||
deviceName: deviceName,
|
||||
}
|
||||
c.statsCache = newStatsCache(statsCacheTTL, c.fetchStats)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
@@ -246,12 +249,6 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// validate if device with name exists
|
||||
_, err = wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return wg.ConfigureDevice(c.deviceName, config)
|
||||
}
|
||||
|
||||
@@ -300,6 +297,14 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
return c.statsCache.get()
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) fetchStats() (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
@@ -326,7 +331,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
52
client/iface/configurer/stats_cache.go
Normal file
52
client/iface/configurer/stats_cache.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const statsCacheTTL = 1 * time.Second
|
||||
|
||||
type statsCache struct {
|
||||
ttl time.Duration
|
||||
fetch func() (map[string]WGStats, error)
|
||||
|
||||
mu sync.RWMutex
|
||||
value map[string]WGStats
|
||||
expireAt time.Time
|
||||
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
func newStatsCache(ttl time.Duration, fetch func() (map[string]WGStats, error)) *statsCache {
|
||||
return &statsCache{ttl: ttl, fetch: fetch}
|
||||
}
|
||||
|
||||
func (c *statsCache) get() (map[string]WGStats, error) {
|
||||
c.mu.RLock()
|
||||
if c.value != nil && time.Now().Before(c.expireAt) {
|
||||
value := c.value
|
||||
c.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
value, err, _ := c.sf.Do("stats", func() (interface{}, error) {
|
||||
res, err := c.fetch()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.value = res
|
||||
c.expireAt = time.Now().Add(c.ttl)
|
||||
c.mu.Unlock()
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return value.(map[string]WGStats), nil
|
||||
}
|
||||
70
client/iface/configurer/stats_cache_test.go
Normal file
70
client/iface/configurer/stats_cache_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStatsCache_CachesWithinTTL(t *testing.T) {
|
||||
var calls atomic.Int64
|
||||
c := newStatsCache(50*time.Millisecond, func() (map[string]WGStats, error) {
|
||||
calls.Add(1)
|
||||
return map[string]WGStats{"p": {}}, nil
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := c.get()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, int64(1), calls.Load(), "within TTL only one underlying fetch")
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
_, err := c.get()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), calls.Load(), "after TTL expiry a fresh fetch happens")
|
||||
}
|
||||
|
||||
func TestStatsCache_SingleFlight(t *testing.T) {
|
||||
var calls atomic.Int64
|
||||
release := make(chan struct{})
|
||||
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
|
||||
calls.Add(1)
|
||||
<-release
|
||||
return map[string]WGStats{}, nil
|
||||
})
|
||||
|
||||
const n = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = c.get()
|
||||
}()
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
close(release)
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int64(1), calls.Load(), "concurrent misses collapse into one fetch")
|
||||
}
|
||||
|
||||
func TestStatsCache_ErrorNotCached(t *testing.T) {
|
||||
var calls atomic.Int64
|
||||
wantErr := errors.New("dump failed")
|
||||
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
|
||||
calls.Add(1)
|
||||
return nil, wantErr
|
||||
})
|
||||
|
||||
_, err := c.get()
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
_, err = c.get()
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
require.Equal(t, int64(2), calls.Load(), "errors are not cached; each call retries")
|
||||
}
|
||||
@@ -40,6 +40,7 @@ type WGUSPConfigurer struct {
|
||||
device *device.Device
|
||||
deviceName string
|
||||
activityRecorder *bind.ActivityRecorder
|
||||
statsCache *statsCache
|
||||
|
||||
uapiListener net.Listener
|
||||
}
|
||||
@@ -50,16 +51,19 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
|
||||
wgCfg.startUAPI()
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
return &WGUSPConfigurer{
|
||||
wgCfg := &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
@@ -348,6 +352,10 @@ func (t *WGUSPConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
return t.statsCache.get()
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) fetchStats() (map[string]WGStats, error) {
|
||||
ipc, err := t.device.IpcGet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ipc get: %w", err)
|
||||
|
||||
@@ -803,17 +803,15 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
if !conn.wgWatcher.PrepareInitialHandshake() {
|
||||
return
|
||||
if !conn.wgWatcher.IsEnabled() {
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
||||
}()
|
||||
}
|
||||
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
||||
}()
|
||||
}
|
||||
|
||||
func (conn *Conn) disableWgWatcherIfNeeded() {
|
||||
|
||||
@@ -31,9 +31,7 @@ type WGWatcher struct {
|
||||
stateDump *stateDump
|
||||
|
||||
enabled bool
|
||||
muEnabled sync.Mutex
|
||||
// initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently.
|
||||
initialHandshake time.Time
|
||||
muEnabled sync.RWMutex
|
||||
|
||||
resetCh chan struct{}
|
||||
}
|
||||
@@ -48,38 +46,38 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
}
|
||||
}
|
||||
|
||||
// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard
|
||||
// handshake time. It must be called before the peer is (re)configured on the WireGuard
|
||||
// interface, so the captured baseline reflects the state prior to this connection attempt
|
||||
// instead of racing with that configuration. Returns ok=false if the watcher is already
|
||||
// running, in which case EnableWgWatcher must not be called.
|
||||
func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
|
||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
||||
w.muEnabled.Lock()
|
||||
if w.enabled {
|
||||
w.muEnabled.Unlock()
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.enabled = true
|
||||
w.muEnabled.Unlock()
|
||||
|
||||
handshake, _ := w.wgState()
|
||||
w.initialHandshake = handshake
|
||||
return true
|
||||
}
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by
|
||||
// PrepareInitialHandshake. The watcher runs until ctx is cancelled. Caller is responsible
|
||||
// for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake)
|
||||
|
||||
w.muEnabled.Lock()
|
||||
w.enabled = false
|
||||
w.muEnabled.Unlock()
|
||||
}
|
||||
|
||||
// IsEnabled returns true if the WireGuard watcher is currently enabled
|
||||
func (w *WGWatcher) IsEnabled() bool {
|
||||
w.muEnabled.RLock()
|
||||
defer w.muEnabled.RUnlock()
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||
func (w *WGWatcher) Reset() {
|
||||
@@ -103,16 +101,13 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
case <-timer.C:
|
||||
handshake, ok := w.handshakeCheck(lastHandshake)
|
||||
if !ok {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
onDisconnectedFn()
|
||||
return
|
||||
}
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := calcElapsed(enabledTime, *handshake)
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
if onHandshakeSuccessFn != nil && ctx.Err() == nil {
|
||||
if onHandshakeSuccessFn != nil {
|
||||
onHandshakeSuccessFn(*handshake)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
@@ -35,9 +34,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ok := watcher.PrepareInitialHandshake()
|
||||
require.True(t, ok, "watcher should not be enabled yet")
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||
mlog.Infof("onDisconnectedFn")
|
||||
@@ -66,9 +62,6 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ok := watcher.PrepareInitialHandshake()
|
||||
require.True(t, ok, "watcher should not be enabled yet")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -83,9 +76,6 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ok = watcher.PrepareInitialHandshake()
|
||||
require.True(t, ok, "watcher should be re-enabled after the previous run stopped")
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||
onDisconnected <- struct{}{}
|
||||
|
||||
@@ -2057,7 +2057,6 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam
|
||||
Extra: &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
},
|
||||
LazyConnectionEnabled: true,
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
|
||||
Reference in New Issue
Block a user