mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
Move wg watcher code to separated file and add tests
This commit is contained in:
123
client/internal/peer/wg_watcher.go
Normal file
123
client/internal/peer/wg_watcher.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
|
||||
const (
|
||||
wgHandshakePeriod = 3 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
wgHandshakeOvertime = 30 * time.Second
|
||||
wgReadErrorRetry = 5 * time.Second
|
||||
checkPeriod = wgHandshakePeriod + wgHandshakeOvertime
|
||||
)
|
||||
|
||||
type WGInterfaceStater interface {
|
||||
GetStats(key string) (configurer.WGStats, error)
|
||||
}
|
||||
|
||||
type WGWatcher struct {
|
||||
log *log.Entry
|
||||
wgIfaceStater WGInterfaceStater
|
||||
peerKey string
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
|
||||
return &WGWatcher{
|
||||
log: log,
|
||||
wgIfaceStater: wgIfaceStater,
|
||||
peerKey: peerKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||
w.ctx = ctx
|
||||
w.ctxCancel = ctxCancel
|
||||
|
||||
w.wgStateCheck(ctx, ctxCancel, onDisconnectedFn)
|
||||
}
|
||||
|
||||
func (w *WGWatcher) DisableWgWatcher() {
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctxCancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("disable WireGuard watcher")
|
||||
|
||||
w.ctxCancel()
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func()) {
|
||||
w.log.Debugf("WireGuard watcher started")
|
||||
lastHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read wg stats: %v", err)
|
||||
lastHandshake = time.Time{}
|
||||
}
|
||||
|
||||
go func(lastHandshake time.Time) {
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
defer ctxCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
handshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to read wg stats: %v", err)
|
||||
timer.Reset(wgReadErrorRetry)
|
||||
continue
|
||||
}
|
||||
|
||||
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
||||
|
||||
if handshake.Equal(lastHandshake) {
|
||||
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
onDisconnectedFn()
|
||||
return
|
||||
}
|
||||
|
||||
resetTime := time.Until(handshake.Add(checkPeriod))
|
||||
lastHandshake = handshake
|
||||
timer.Reset(resetTime)
|
||||
case <-ctx.Done():
|
||||
w.log.Debugf("WireGuard watcher stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}(lastHandshake)
|
||||
}
|
||||
|
||||
func (w *WGWatcher) wgState() (time.Time, error) {
|
||||
wgState, err := w.wgIfaceStater.GetStats(w.peerKey)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return wgState.LastHandshake, nil
|
||||
}
|
||||
98
client/internal/peer/wg_watcher_test.go
Normal file
98
client/internal/peer/wg_watcher_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
|
||||
type MocWgIface struct {
|
||||
initial bool
|
||||
lastHandshake time.Time
|
||||
stop bool
|
||||
}
|
||||
|
||||
func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) {
|
||||
if !m.initial {
|
||||
m.initial = true
|
||||
return configurer.WGStats{}, nil
|
||||
}
|
||||
|
||||
if !m.stop {
|
||||
m.lastHandshake = time.Now()
|
||||
}
|
||||
|
||||
stats := configurer.WGStats{
|
||||
LastHandshake: m.lastHandshake,
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (m *MocWgIface) disconnect() {
|
||||
m.stop = true
|
||||
}
|
||||
|
||||
func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
checkPeriod = 5 * time.Second
|
||||
wgHandshakeOvertime = 1 * time.Second
|
||||
|
||||
mlog := log.WithField("peer", "tet")
|
||||
mocWgIface := &MocWgIface{}
|
||||
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
watcher.EnableWgWatcher(ctx, func() {
|
||||
mlog.Infof("onDisconnectedFn")
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
|
||||
// wait for initial reading
|
||||
time.Sleep(2 * time.Second)
|
||||
mocWgIface.disconnect()
|
||||
|
||||
select {
|
||||
case <-onDisconnected:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
watcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
checkPeriod = 5 * time.Second
|
||||
wgHandshakeOvertime = 1 * time.Second
|
||||
|
||||
mlog := log.WithField("peer", "tet")
|
||||
mocWgIface := &MocWgIface{}
|
||||
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
|
||||
watcher.EnableWgWatcher(ctx, func() {})
|
||||
watcher.DisableWgWatcher()
|
||||
|
||||
watcher.EnableWgWatcher(ctx, func() {
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
mocWgIface.disconnect()
|
||||
|
||||
select {
|
||||
case <-onDisconnected:
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
watcher.DisableWgWatcher()
|
||||
}
|
||||
@@ -6,18 +6,12 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
)
|
||||
|
||||
var (
|
||||
wgHandshakePeriod = 3 * time.Minute
|
||||
wgHandshakeOvertime = 30 * time.Second
|
||||
)
|
||||
|
||||
type RelayConnInfo struct {
|
||||
relayedConn net.Conn
|
||||
rosenpassPubKey []byte
|
||||
@@ -36,13 +30,12 @@ type WorkerRelay struct {
|
||||
relayManager relayClient.ManagerService
|
||||
callBacks WorkerRelayCallbacks
|
||||
|
||||
relayedConn net.Conn
|
||||
relayLock sync.Mutex
|
||||
ctxWgWatch context.Context
|
||||
ctxCancelWgWatch context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
relayedConn net.Conn
|
||||
relayLock sync.Mutex
|
||||
|
||||
relaySupportedOnRemotePeer atomic.Bool
|
||||
|
||||
wgWatcher *WGWatcher
|
||||
}
|
||||
|
||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
|
||||
@@ -52,6 +45,7 @@ func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager r
|
||||
config: config,
|
||||
relayManager: relayManager,
|
||||
callBacks: callbacks,
|
||||
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key),
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -103,32 +97,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(ctx)
|
||||
w.ctxWgWatch = ctx
|
||||
w.ctxCancelWgWatch = ctxCancel
|
||||
|
||||
w.wgStateCheck(ctx, ctxCancel)
|
||||
w.wgWatcher.EnableWgWatcher(ctx, w.disconnected)
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) DisableWgWatcher() {
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctxCancelWgWatch == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("disable WireGuard watcher")
|
||||
|
||||
w.ctxCancelWgWatch()
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||
@@ -150,57 +123,17 @@ func (w *WorkerRelay) CloseConn() {
|
||||
return
|
||||
}
|
||||
|
||||
err := w.relayedConn.Close()
|
||||
if err != nil {
|
||||
if err := w.relayedConn.Close(); err != nil {
|
||||
w.log.Warnf("failed to close relay connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
|
||||
w.log.Debugf("WireGuard watcher started")
|
||||
lastHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read wg stats: %v", err)
|
||||
lastHandshake = time.Time{}
|
||||
}
|
||||
|
||||
go func(lastHandshake time.Time) {
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
defer ctxCancel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
handshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to read wg stats: %v", err)
|
||||
timer.Reset(wgHandshakeOvertime)
|
||||
continue
|
||||
}
|
||||
|
||||
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
|
||||
|
||||
if handshake.Equal(lastHandshake) {
|
||||
w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
w.callBacks.OnDisconnected()
|
||||
return
|
||||
}
|
||||
|
||||
resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
|
||||
lastHandshake = handshake
|
||||
timer.Reset(resetTime)
|
||||
case <-ctx.Done():
|
||||
w.log.Debugf("WireGuard watcher stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}(lastHandshake)
|
||||
func (w *WorkerRelay) disconnected() {
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
|
||||
w.callBacks.OnDisconnected()
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||
@@ -217,20 +150,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
||||
return remoteRelayAddress
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) wgState() (time.Time, error) {
|
||||
wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return wgState.LastHandshake, nil
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onRelayMGDisconnected() {
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctxCancelWgWatch != nil {
|
||||
w.ctxCancelWgWatch()
|
||||
}
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
go w.callBacks.OnDisconnected()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user