Compare commits

...

4 Commits

Author SHA1 Message Date
riccardom
5740dd22e6 Remove verbose comments 2026-07-03 17:26:20 +02:00
riccardom
ec98c930cb [Recheck watcher ctx cancellation under conn.mu in onWGDisconnected
onWGDisconnected only checked conn.ctx (the engine-scoped context), never
the watcher's own context. disableWgWatcherIfNeeded cancels the wgWatcherCtx,
not conn.ctx, so a disabled watcher's timeout callback did not see the
cancellation.

handshakeCheck runs lock-free, so between the ctx check in periodicHandshakeCheck
and acquiring conn.mu a fast disconnect/reconnect can slip in: the stale watcher
then acquires the lock and tears down the *new*, healthy connection based on the
old timeout, forcing the guard into an unnecessary reconnect (flap).

Recheck watcherCtx.Err() under conn.mu so a superseded watcher exits without
touching the connection that replaced it.
2026-07-03 12:15:24 +02:00
riccardom
60104e000b Discriminate not updated from timeout handshakes 2026-07-03 12:02:50 +02:00
riccardom
d5a212349f Stick new watcher creation to actual existence of af the conn
and its removal to the removal of such same conn.
Avoid debouncing and cross lock dead locking
2026-07-03 11:37:41 +02:00
3 changed files with 38 additions and 41 deletions

View File

@@ -195,7 +195,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
@@ -663,11 +662,12 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) onWGDisconnected() {
func (conn *Conn) onWGDisconnected(watcherCtx context.Context) {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
// watcherCtx guards against a stale watcher tearing down a connection that already superseded it.
if conn.ctx.Err() != nil || watcherCtx.Err() != nil {
return
}
@@ -802,25 +802,39 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
})
}
// enableWgWatcherIfNeeded starts a fresh watcher instance per connection attempt, so its
// lifecycle stays bound to conn.mu and enable/disable can't race an old goroutine's shutdown.
// Caller must hold conn.mu.
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
if !conn.wgWatcher.PrepareInitialHandshake() {
if conn.wgWatcher != nil {
return
}
watcher := NewWGWatcher(conn.Log, conn.config.WgConfig.WgInterface, conn.config.Key, conn.dumpState)
watcher.PrepareInitialHandshake()
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcher = watcher
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
onDisconnected := func() { conn.onWGDisconnected(wgWatcherCtx) }
watcher.EnableWgWatcher(wgWatcherCtx, enabledTime, onDisconnected, conn.onWGHandshakeSuccess)
}()
}
// disableWgWatcherIfNeeded cancels and drops the watcher once no transport is active. It never
// waits for the goroutine: the timeout path reentrantly calls back here under conn.mu, so
// blocking would deadlock. Caller must hold conn.mu.
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
if conn.currentConnPriority != conntype.None || conn.wgWatcher == nil {
return
}
conn.wgWatcherCancel()
conn.wgWatcher = nil
conn.wgWatcherCancel = nil
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
@@ -843,7 +857,9 @@ func (conn *Conn) resetEndpoint() {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if conn.wgWatcher != nil {
conn.wgWatcher.Reset()
}
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}

View File

@@ -3,7 +3,6 @@ package peer
import (
"context"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -24,14 +23,14 @@ type WGInterfaceStater interface {
GetStats() (map[string]configurer.WGStats, error)
}
// WGWatcher is single-shot: one instance per connection attempt, run once, then discarded.
// Lifecycle is owned by Conn under conn.mu, so it keeps no "enabled" state to go stale.
type WGWatcher struct {
log *log.Entry
wgIfaceStater WGInterfaceStater
peerKey string
stateDump *stateDump
enabled bool
muEnabled sync.Mutex
// initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently.
initialHandshake time.Time
@@ -48,25 +47,14 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
}
// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard
// handshake time. It must be called before the peer is (re)configured on the WireGuard
// interface, so the captured baseline reflects the state prior to this connection attempt
// instead of racing with that configuration. Returns ok=false if the watcher is already
// running, in which case EnableWgWatcher must not be called.
func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return false
}
// PrepareInitialHandshake reads the peer's current WireGuard handshake time. It must be
// called before the peer is (re)configured on the WireGuard interface, so the captured
// baseline reflects the state prior to this connection attempt instead of racing with
// that configuration.
func (w *WGWatcher) PrepareInitialHandshake() {
w.log.Debugf("enable WireGuard watcher")
w.enabled = true
w.muEnabled.Unlock()
handshake, _ := w.wgState()
w.initialHandshake = handshake
return true
}
// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by
@@ -74,10 +62,6 @@ func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
// for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}
// Reset signals the watcher that the WireGuard peer has been reset and a new
@@ -103,6 +87,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
case <-timer.C:
handshake, ok := w.handshakeCheck(lastHandshake)
if !ok {
// early ctx cancel check return
if ctx.Err() != nil {
return
}
@@ -147,9 +132,9 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
// the current know handshake did not change
// the current known handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
w.log.Warnf("WireGuard handshake not updated: %v", handshake)
return nil, false
}

View File

@@ -7,7 +7,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/configurer"
)
@@ -35,8 +34,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ok := watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should not be enabled yet")
watcher.PrepareInitialHandshake()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
@@ -66,8 +64,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
ok := watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should not be enabled yet")
watcher.PrepareInitialHandshake()
wg := &sync.WaitGroup{}
wg.Add(1)
@@ -83,8 +80,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
ok = watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should be re-enabled after the previous run stopped")
watcher.PrepareInitialHandshake()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, time.Now(), func() {