If can not read WireGuard state then trigger reconnection

This commit is contained in:
Zoltán Papp
2025-01-30 13:37:00 +01:00
parent 759544f2c3
commit ef5e417cb7

View File

@@ -16,7 +16,6 @@ const (
var (
wgHandshakeOvertime = 30 * time.Second
wgReadErrorRetry = 5 * time.Second
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
)
@@ -51,11 +50,14 @@ func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn
return
}
ctx, ctxCancel := context.WithCancel(parentCtx)
w.ctx = ctx
w.ctxCancel = ctxCancel
w.ctx, w.ctxCancel = context.WithCancel(parentCtx)
w.wgStateCheck(ctx, ctxCancel, onDisconnectedFn)
initialHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
}
go w.periodicHandshakeCheck(w.ctx, w.ctxCancel, onDisconnectedFn, initialHandshake)
}
func (w *WGWatcher) DisableWgWatcher() {
@@ -72,46 +74,30 @@ func (w *WGWatcher) DisableWgWatcher() {
}
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func()) {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Debugf("WireGuard watcher started")
lastHandshake, err := w.wgState()
if err != nil {
w.log.Warnf("failed to read wg stats: %v", err)
lastHandshake = time.Time{}
}
go func(lastHandshake time.Time) {
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
for {
select {
case <-timer.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
timer.Reset(wgReadErrorRetry)
continue
}
lastHandshake := initialHandshake
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
onDisconnectedFn()
return
}
resetTime := time.Until(handshake.Add(checkPeriod))
lastHandshake = handshake
timer.Reset(resetTime)
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
for {
select {
case <-timer.C:
handshake, ok := w.handshakeCheck(lastHandshake)
if !ok {
onDisconnectedFn()
return
}
timer.Reset(time.Until(handshake.Add(checkPeriod)))
lastHandshake = *handshake
case <-ctx.Done():
w.log.Debugf("WireGuard watcher stopped")
return
}
}(lastHandshake)
}
}
func (w *WGWatcher) wgState() (time.Time, error) {
@@ -121,3 +107,20 @@ func (w *WGWatcher) wgState() (time.Time, error) {
}
return wgState.LastHandshake, nil
}
func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
return nil, false
}
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
if handshake.Equal(lastHandshake) {
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
return nil, false
}
return &handshake, true
}