If can not read WireGuard state then trigger reconnection

This commit is contained in:
Zoltán Papp
2025-01-30 13:37:00 +01:00
parent 759544f2c3
commit ef5e417cb7

View File

@@ -16,7 +16,6 @@ const (
var ( var (
wgHandshakeOvertime = 30 * time.Second wgHandshakeOvertime = 30 * time.Second
wgReadErrorRetry = 5 * time.Second
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
) )
@@ -51,11 +50,14 @@ func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn
return return
} }
ctx, ctxCancel := context.WithCancel(parentCtx) w.ctx, w.ctxCancel = context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.wgStateCheck(ctx, ctxCancel, onDisconnectedFn) initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
}
go w.periodicHandshakeCheck(w.ctx, w.ctxCancel, onDisconnectedFn, initialHandshake)
} }
func (w *WGWatcher) DisableWgWatcher() { func (w *WGWatcher) DisableWgWatcher() {
@@ -72,46 +74,30 @@ func (w *WGWatcher) DisableWgWatcher() {
} }
// wgStateCheck help to check the state of the WireGuard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func()) { func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Debugf("WireGuard watcher started") w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
go func(lastHandshake time.Time) { timer := time.NewTimer(wgHandshakeOvertime)
timer := time.NewTimer(wgHandshakeOvertime) defer timer.Stop()
defer timer.Stop() defer ctxCancel()
defer ctxCancel()
for { lastHandshake := initialHandshake
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgReadErrorRetry)
continue
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) for {
select {
if handshake.Equal(lastHandshake) { case <-timer.C:
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) handshake, ok := w.handshakeCheck(lastHandshake)
onDisconnectedFn() if !ok {
return onDisconnectedFn()
}
resetTime := time.Until(handshake.Add(checkPeriod))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return return
} }
timer.Reset(time.Until(handshake.Add(checkPeriod)))
lastHandshake = *handshake
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
} }
}(lastHandshake) }
} }
func (w *WGWatcher) wgState() (time.Time, error) { func (w *WGWatcher) wgState() (time.Time, error) {
@@ -121,3 +107,20 @@ func (w *WGWatcher) wgState() (time.Time, error) {
} }
return wgState.LastHandshake, nil return wgState.LastHandshake, nil
} }
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
return nil, false
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
return nil, false
}
return &handshake, true
}