Compare commits

...

14 Commits

Author SHA1 Message Date
Maycon Santos
c2180e4bb2 run windows iface tests without race flag 2025-07-21 00:06:50 +02:00
Maycon Santos
fc735d1337 fix test timer 2025-07-21 00:02:50 +02:00
Maycon Santos
33f4c3bd3f use getPeerListener 2025-07-20 23:38:24 +02:00
Maycon Santos
b6cef0cd26 skip race flag on 386 2025-07-20 23:13:35 +02:00
Maycon Santos
cc78a3c65f introduce ConnPriorityStore 2025-07-20 23:09:42 +02:00
Maycon Santos
cc1c77f6dc use set/get for reconnectingTimeout 2025-07-20 22:16:35 +02:00
Maycon Santos
d7d57a4ec4 fix invalid log format 2025-07-20 22:13:03 +02:00
Maycon Santos
09d0fea5ca use setters and getters for healthCheckInterval and healthCheckTimeout 2025-07-20 22:09:03 +02:00
Maycon Santos
4e737a482b rename and use getHeartBeatTimeout and setHeartBeatTimeout 2025-07-20 21:57:24 +02:00
Maycon Santos
62deb64f5f use getHealthCheckInterval and setHealthCheckInterval 2025-07-20 21:37:49 +02:00
Maycon Santos
bdb38dfa57 use relay manager.getClientLen() 2025-07-20 21:31:59 +02:00
Maycon Santos
84988b4d53 update checkChangeFn test usage 2025-07-20 21:24:51 +02:00
Maycon Santos
7b4cc63054 update some tests 2025-07-18 19:30:24 +02:00
Maycon Santos
e66412da1b add race flag to client tests
using a temporary fix for ICE for now
2025-07-18 19:12:44 +02:00
25 changed files with 251 additions and 86 deletions

View File

@@ -42,5 +42,5 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -race -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -103,7 +103,11 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -144,7 +148,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' ${{ matrix.raceFlag }} -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker: test_client_on_docker:
name: "Client (Docker) / Unit" name: "Client (Docker) / Unit"

View File

@@ -63,10 +63,16 @@ jobs:
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/client/iface' } )" >> $env:GITHUB_ENV
- name: test - name: test without iface
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -race -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output - name: test output without iface
if: ${{ always() }}
run: Get-Content test-out.txt
# todo: remove this once iface tests are stable with race flag
- name: test iface
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ./client/iface/... > test-out.txt 2>&1"
- name: test output iface
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@@ -1,3 +1,5 @@
//go:build !race
package cmd package cmd
import ( import (

View File

@@ -218,3 +218,9 @@ func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []by
TxBytes: conn.BytesTx.Load(), TxBytes: conn.BytesTx.Load(),
}) })
} }
func (t *UDPTracker) getConnectionsLen() int {
t.mutex.RLock()
defer t.mutex.RUnlock()
return len(t.connections)
}

View File

@@ -202,13 +202,13 @@ func TestUDPTracker_Cleanup(t *testing.T) {
} }
// Verify initial connections // Verify initial connections
assert.Len(t, tracker.connections, 2) assert.Equal(t, 2, tracker.getConnectionsLen())
// Wait for connection timeout and cleanup interval // Wait for connection timeout and cleanup interval
time.Sleep(timeout + 2*cleanupInterval) time.Sleep(timeout + 2*cleanupInterval)
tracker.mutex.RLock() tracker.mutex.RLock()
connCount := len(tracker.connections) connCount := tracker.getConnectionsLen()
tracker.mutex.RUnlock() tracker.mutex.RUnlock()
// Verify connections were cleaned up // Verify connections were cleaned up

View File

@@ -50,7 +50,7 @@ type upstreamResolverBase struct {
upstreamClient upstreamClient upstreamClient upstreamClient
upstreamServers []string upstreamServers []string
domain string domain string
disabled bool disabled atomic.Bool
failsCount atomic.Int32 failsCount atomic.Int32
successCount atomic.Int32 successCount atomic.Int32
failsTillDeact int32 failsTillDeact int32
@@ -176,7 +176,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
u.mutex.Lock() u.mutex.Lock()
defer u.mutex.Unlock() defer u.mutex.Unlock()
if u.failsCount.Load() < u.failsTillDeact || u.disabled { if u.failsCount.Load() < u.failsTillDeact || u.disabled.Load() {
return return
} }
@@ -305,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
u.failsCount.Store(0) u.failsCount.Store(0)
u.successCount.Add(1) u.successCount.Add(1)
u.reactivate() u.reactivate()
u.disabled = false u.disabled.Store(false)
} }
// isTimeout returns true if the given error is a network timeout error. // isTimeout returns true if the given error is a network timeout error.
@@ -320,14 +320,14 @@ func isTimeout(err error) bool {
} }
func (u *upstreamResolverBase) disable(err error) { func (u *upstreamResolverBase) disable(err error) {
if u.disabled { if u.disabled.Load() {
return return
} }
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0) u.successCount.Store(0)
u.deactivate(err) u.deactivate(err)
u.disabled = true u.disabled.Store(true)
go u.waitUntilResponse() go u.waitUntilResponse()
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"net/netip" "net/netip"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -135,32 +136,41 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
responseWriter := &test.MockResponseWriter{ responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error { return nil }, WriteMsgFunc: func(m *dns.Msg) error { return nil },
} }
lmux := sync.Mutex{}
failed := false failed := false
resolver.deactivate = func(error) { resolver.deactivate = func(error) {
lmux.Lock()
failed = true failed = true
lmux.Unlock()
} }
reactivated := false reactivated := false
resolver.reactivate = func() { resolver.reactivate = func() {
lmux.Lock()
reactivated = true reactivated = true
lmux.Unlock()
} }
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA)) resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
lmux.Lock()
if !failed { failedCheck := failed
lmux.Unlock()
if !failedCheck {
t.Errorf("expected that resolving was deactivated") t.Errorf("expected that resolving was deactivated")
return return
} }
if !resolver.disabled { if !resolver.disabled.Load() {
t.Errorf("resolver should be Disabled") t.Errorf("resolver should be Disabled")
return return
} }
time.Sleep(time.Millisecond * 200) time.Sleep(time.Millisecond * 200)
if !reactivated { lmux.Lock()
checkReactivated := reactivated
lmux.Unlock()
if !checkReactivated {
t.Errorf("expected that resolving was reactivated") t.Errorf("expected that resolving was reactivated")
return return
} }
@@ -170,7 +180,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
return return
} }
if resolver.disabled { if resolver.disabled.Load() {
t.Errorf("should be enabled") t.Errorf("should be enabled")
} }
} }

View File

@@ -836,7 +836,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
go func() { go func() {
// blocking // blocking
err = e.sshServer.Start() e.syncMsgMux.Lock()
sshServer := e.sshServer
e.syncMsgMux.Unlock()
err = sshServer.Start()
if err != nil { if err != nil {
// will throw error when we stop it even if it is a graceful stop // will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err) log.Debugf("stopped SSH server with error %v", err)
@@ -851,6 +854,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
} else if !isNil(e.sshServer) { } else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running // Disable SSH server request, so stop it if it was running
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
err := e.sshServer.Stop() err := e.sshServer.Stop()
if err != nil { if err != nil {
log.Warnf("failed to stop SSH server %v", err) log.Warnf("failed to stop SSH server %v", err)

View File

@@ -102,3 +102,11 @@ func (m *Manager) notify(peerConnID peerid.ConnID) {
case m.OnActivityChan <- peerConnID: case m.OnActivityChan <- peerConnID:
} }
} }
func (m *Manager) getPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
m.mu.Lock()
defer m.mu.Unlock()
listener, ok := m.peers[peerConnID]
return listener, ok
}

View File

@@ -50,8 +50,11 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil { if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
t.Fatalf("failed to monitor peer activity: %v", err) t.Fatalf("failed to monitor peer activity: %v", err)
} }
listener, ok := mgr.getPeerListener(peerCfg1.PeerConnID)
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil { if !ok {
t.Fatalf("failed to get peer listener: %s", peerCfg1.PublicKey)
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err) t.Fatalf("failed to trigger activity: %v", err)
} }
@@ -83,7 +86,12 @@ func TestManager_RemovePeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err) t.Fatalf("failed to monitor peer activity: %v", err)
} }
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() peer1Listener, ok := mgr.getPeerListener(peerCfg1.PeerConnID)
if !ok {
t.Fatalf("failed to get peer listener: %s", peerCfg1.PublicKey)
}
addr := peer1Listener.conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
@@ -128,11 +136,20 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err) t.Fatalf("failed to monitor peer activity: %v", err)
} }
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil { peer1Listener, ok := mgr.getPeerListener(peerCfg1.PeerConnID)
if !ok {
t.Fatalf("failed to get peer listener: %s", peerCfg1.PublicKey)
}
if err := trigger(peer1Listener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err) t.Fatalf("failed to trigger activity: %v", err)
} }
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil { peer2Listener, ok := mgr.getPeerListener(peerCfg2.PeerConnID)
if !ok {
t.Fatalf("failed to get peer listener: %s", peerCfg2.PublicKey)
}
if err := trigger(peer2Listener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err) t.Fatalf("failed to trigger activity: %v", err)
} }

View File

@@ -22,6 +22,19 @@ const (
) )
var checkChangeFn = checkChange var checkChangeFn = checkChange
var mux sync.Mutex
func getCheckChangeFn() func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
mux.Lock()
defer mux.Unlock()
return checkChangeFn
}
func setCheckChangeFn(fn func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error) {
mux.Lock()
defer mux.Unlock()
checkChangeFn = fn
}
// NetworkMonitor watches for changes in network configuration. // NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct { type NetworkMonitor struct {
@@ -120,7 +133,8 @@ func (nw *NetworkMonitor) Stop() {
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
defer close(event) defer close(event)
for { for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { checkFn := getCheckChangeFn()
if err := checkFn(ctx, nexthop4, nexthop6); err != nil {
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {
log.Errorf("Network monitor: failed to check for changes: %v", err) log.Errorf("Network monitor: failed to check for changes: %v", err)
} }

View File

@@ -25,10 +25,10 @@ func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 sy
} }
func TestNetworkMonitor_Close(t *testing.T) { func TestNetworkMonitor_Close(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { setCheckChangeFn(func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
<-ctx.Done() <-ctx.Done()
return ctx.Err() return ctx.Err()
} })
nw := New() nw := New()
var resErr error var resErr error
@@ -48,7 +48,7 @@ func TestNetworkMonitor_Close(t *testing.T) {
} }
func TestNetworkMonitor_Event(t *testing.T) { func TestNetworkMonitor_Event(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { setCheckChangeFn(func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
timeout, cancel := context.WithTimeout(ctx, 3*time.Second) timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel() defer cancel()
select { select {
@@ -57,7 +57,7 @@ func TestNetworkMonitor_Event(t *testing.T) {
case <-timeout.Done(): case <-timeout.Done():
return nil return nil
} }
} })
nw := New() nw := New()
defer nw.Stop() defer nw.Stop()
@@ -77,7 +77,7 @@ func TestNetworkMonitor_Event(t *testing.T) {
func TestNetworkMonitor_MultiEvent(t *testing.T) { func TestNetworkMonitor_MultiEvent(t *testing.T) {
eventsRepeated := 3 eventsRepeated := 3
me := &MocMultiEvent{counter: eventsRepeated} me := &MocMultiEvent{counter: eventsRepeated}
checkChangeFn = me.checkChange setCheckChangeFn(me.checkChange)
nw := New() nw := New()
defer nw.Stop() defer nw.Stop()

View File

@@ -99,7 +99,7 @@ type Conn struct {
statusRelay *worker.AtomicWorkerStatus statusRelay *worker.AtomicWorkerStatus
statusICE *worker.AtomicWorkerStatus statusICE *worker.AtomicWorkerStatus
currentConnPriority conntype.ConnPriority currentConnPriority conntype.ConnPriorityStore
opened bool // this flag is used to prevent close in case of not opened connection opened bool // this flag is used to prevent close in case of not opened connection
workerICE *WorkerICE workerICE *WorkerICE
@@ -283,7 +283,7 @@ func (conn *Conn) Close(signalToRemote bool) {
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority.Get(), conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) return conn.handshaker.OnRemoteAnswer(answer)
} }
@@ -353,8 +353,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade // this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
// todo consider to remove this check // todo consider to remove this check
if conn.currentConnPriority > priority { if conn.currentConnPriority.Get() > priority {
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority) conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority.Get(), priority)
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
return return
@@ -408,7 +408,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
} }
wgConfigWorkaround() wgConfigWorkaround()
conn.currentConnPriority = priority conn.currentConnPriority.Set(priority)
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
@@ -445,10 +445,10 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done() defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
}() }()
conn.currentConnPriority = conntype.Relay conn.currentConnPriority.Set(conntype.Relay)
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority.Set(conntype.None)
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -496,7 +496,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.isICEActive() { if conn.isICEActive() {
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.Get().String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.SetConnected() conn.statusRelay.SetConnected()
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
@@ -524,7 +524,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround() wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay conn.currentConnPriority.Set(conntype.Relay)
conn.statusRelay.SetConnected() conn.statusRelay.SetConnected()
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
@@ -542,9 +542,9 @@ func (conn *Conn) onRelayDisconnected() {
conn.Log.Debugf("relay connection is disconnected") conn.Log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority.Get() == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None conn.currentConnPriority.Set(conntype.None)
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -626,7 +626,7 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
func (conn *Conn) setStatusToDisconnected() { func (conn *Conn) setStatusToDisconnected() {
conn.statusRelay.SetDisconnected() conn.statusRelay.SetDisconnected()
conn.statusICE.SetDisconnected() conn.statusICE.SetDisconnected()
conn.currentConnPriority = conntype.None conn.currentConnPriority.Set(conntype.None)
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -669,7 +669,7 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
} }
func (conn *Conn) isRelayed() bool { func (conn *Conn) isRelayed() bool {
switch conn.currentConnPriority { switch conn.currentConnPriority.Get() {
case conntype.Relay, conntype.ICETurn: case conntype.Relay, conntype.ICETurn:
return true return true
default: default:
@@ -753,11 +753,11 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
} }
func (conn *Conn) isReadyToUpgrade() bool { func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay return conn.wgProxyRelay != nil && conn.currentConnPriority.Get() != conntype.Relay
} }
func (conn *Conn) isICEActive() bool { func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected return (conn.currentConnPriority.Get() == conntype.ICEP2P || conn.currentConnPriority.Get() == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
} }
func (conn *Conn) removeWgPeer() error { func (conn *Conn) removeWgPeer() error {

View File

@@ -2,6 +2,7 @@ package conntype
import ( import (
"fmt" "fmt"
"sync/atomic"
) )
const ( const (
@@ -11,7 +12,7 @@ const (
ICEP2P ConnPriority = 3 ICEP2P ConnPriority = 3
) )
type ConnPriority int type ConnPriority int32
func (cp ConnPriority) String() string { func (cp ConnPriority) String() string {
switch cp { switch cp {
@@ -27,3 +28,15 @@ func (cp ConnPriority) String() string {
return fmt.Sprintf("ConnPriority(%d)", cp) return fmt.Sprintf("ConnPriority(%d)", cp)
} }
} }
type ConnPriorityStore struct {
store atomic.Int32
}
func (cps *ConnPriorityStore) Get() ConnPriority {
return ConnPriority(cps.store.Load())
}
func (cps *ConnPriorityStore) Set(cp ConnPriority) {
cps.store.Store(int32(cp))
}

View File

@@ -9,30 +9,54 @@ type mocListener struct {
lastState int lastState int
wg sync.WaitGroup wg sync.WaitGroup
peers int peers int
mux sync.Mutex
} }
func (l *mocListener) OnConnected() { func (l *mocListener) OnConnected() {
l.mux.Lock()
defer l.mux.Unlock()
l.lastState = stateConnected l.lastState = stateConnected
l.wg.Done() l.wg.Done()
} }
func (l *mocListener) OnDisconnected() { func (l *mocListener) OnDisconnected() {
l.mux.Lock()
defer l.mux.Unlock()
l.lastState = stateDisconnected l.lastState = stateDisconnected
l.wg.Done() l.wg.Done()
} }
func (l *mocListener) OnConnecting() { func (l *mocListener) OnConnecting() {
l.mux.Lock()
defer l.mux.Unlock()
l.lastState = stateConnecting l.lastState = stateConnecting
l.wg.Done() l.wg.Done()
} }
func (l *mocListener) OnDisconnecting() { func (l *mocListener) OnDisconnecting() {
l.mux.Lock()
defer l.mux.Unlock()
l.lastState = stateDisconnecting l.lastState = stateDisconnecting
l.wg.Done() l.wg.Done()
}
func (l *mocListener) getLastState() int {
l.mux.Lock()
defer l.mux.Unlock()
return l.lastState
} }
func (l *mocListener) OnAddressChanged(host, addr string) { func (l *mocListener) OnAddressChanged(host, addr string) {
} }
func (l *mocListener) OnPeersListChanged(size int) { func (l *mocListener) OnPeersListChanged(size int) {
l.mux.Lock()
l.peers = size l.peers = size
l.mux.Unlock()
}
func (l *mocListener) getPeers() int {
l.mux.Lock()
defer l.mux.Unlock()
return l.peers
} }
func (l *mocListener) setWaiter() { func (l *mocListener) setWaiter() {
@@ -77,7 +101,7 @@ func Test_notifier_SetListener(t *testing.T) {
n.lastNotification = stateConnecting n.lastNotification = stateConnecting
n.setListener(listener) n.setListener(listener)
listener.wait() listener.wait()
if listener.lastState != n.lastNotification { if listener.getLastState() != n.lastNotification {
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
} }
} }
@@ -91,7 +115,7 @@ func Test_notifier_RemoveListener(t *testing.T) {
n.removeListener() n.removeListener()
n.peerListChanged(1) n.peerListChanged(1)
if listener.peers != 0 { if listener.getPeers() != 0 {
t.Errorf("invalid state: %d", listener.peers) t.Errorf("invalid state: %d", listener.peers)
} }
} }

2
go.mod
View File

@@ -257,6 +257,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20250718163601-725c8ac53a31
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944

4
go.sum
View File

@@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20250718163601-725c8ac53a31 h1:lr/CnQ9NnlHr4yjDaCqy3V1FW+y9DDpzqxu1+YXzXtc=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20250718163601-725c8ac53a31/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=

View File

@@ -2,6 +2,7 @@ package client
import ( import (
"context" "context"
"sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@@ -10,8 +11,21 @@ import (
var ( var (
reconnectingTimeout = 60 * time.Second reconnectingTimeout = 60 * time.Second
mux sync.Mutex
) )
func getReconnectingTimeout() time.Duration {
mux.Lock()
defer mux.Unlock()
return reconnectingTimeout
}
func setReconnectingTimeout(timeout time.Duration) {
mux.Lock()
defer mux.Unlock()
reconnectingTimeout = timeout
}
// Guard manage the reconnection tries to the Relay server in case of disconnection event. // Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct { type Guard struct {
// OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance. // OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance.
@@ -128,7 +142,7 @@ func exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second, InitialInterval: 2 * time.Second,
Multiplier: 2, Multiplier: 2,
MaxInterval: reconnectingTimeout, MaxInterval: getReconnectingTimeout(),
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)

View File

@@ -350,6 +350,12 @@ func (m *Manager) startCleanupLoop() {
} }
} }
func (m *Manager) getClientLen() int {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
return len(m.relayClients)
}
func (m *Manager) cleanUpUnusedRelays() { func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock() m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock() defer m.relayClientsMutex.Unlock()

View File

@@ -292,8 +292,8 @@ func TestForeginAutoClose(t *testing.T) {
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
t.Logf("waiting for relay cleanup: %s", timeout) t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout) time.Sleep(timeout)
if len(mgr.relayClients) != 0 { if mgr.getClientLen() != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients)) t.Errorf("expected 0, got %d", mgr.getClientLen())
} }
t.Logf("closing manager") t.Logf("closing manager")
@@ -301,7 +301,7 @@ func TestForeginAutoClose(t *testing.T) {
func TestAutoReconnect(t *testing.T) { func TestAutoReconnect(t *testing.T) {
ctx := context.Background() ctx := context.Background()
reconnectingTimeout = 2 * time.Second setReconnectingTimeout(2 * time.Second)
srvCfg := server.ListenerConfig{ srvCfg := server.ListenerConfig{
Address: "localhost:1234", Address: "localhost:1234",
@@ -362,7 +362,7 @@ func TestAutoReconnect(t *testing.T) {
} }
log.Infof("waiting for reconnection") log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second) time.Sleep(getReconnectingTimeout() + 1*time.Second)
log.Infof("reopent the connection") log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ctx, ra, "bob") _, err = clientAlice.OpenConn(ctx, ra, "bob")

View File

@@ -2,14 +2,26 @@ package healthcheck
import ( import (
"context" "context"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( var heartbeatTimeout = getHealthCheckInterval() + 10*time.Second
heartbeatTimeout = healthCheckInterval + 10*time.Second var mux sync.Mutex
)
func getHeartBeatTimeout() time.Duration {
mux.Lock()
defer mux.Unlock()
return heartbeatTimeout
}
func setHeartBeatTimeout(interval time.Duration) {
mux.Lock()
defer mux.Unlock()
heartbeatTimeout = interval
}
// Receiver is a healthcheck receiver // Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time // It will listen for heartbeat and check if the heartbeat is not received in a certain time
@@ -56,7 +68,7 @@ func (r *Receiver) Stop() {
} }
func (r *Receiver) waitForHealthcheck() { func (r *Receiver) waitForHealthcheck() {
ticker := time.NewTicker(heartbeatTimeout) ticker := time.NewTicker(getHeartBeatTimeout())
defer ticker.Stop() defer ticker.Stop()
defer r.ctxCancel() defer r.ctxCancel()
defer close(r.OnTimeout) defer close(r.OnTimeout)

View File

@@ -11,7 +11,7 @@ import (
) )
func TestNewReceiver(t *testing.T) { func TestNewReceiver(t *testing.T) {
heartbeatTimeout = 5 * time.Second setHeartBeatTimeout(5 * time.Second)
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
select { select {
@@ -23,7 +23,7 @@ func TestNewReceiver(t *testing.T) {
} }
func TestNewReceiverNotReceive(t *testing.T) { func TestNewReceiverNotReceive(t *testing.T) {
heartbeatTimeout = 1 * time.Second setHeartBeatTimeout(1 * time.Second)
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
select { select {
@@ -34,7 +34,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
} }
func TestNewReceiverAck(t *testing.T) { func TestNewReceiverAck(t *testing.T) {
heartbeatTimeout = 2 * time.Second setHeartBeatTimeout(2 * time.Second)
r := NewReceiver(log.WithContext(context.Background())) r := NewReceiver(log.WithContext(context.Background()))
r.Heartbeat() r.Heartbeat()
@@ -59,13 +59,13 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases { for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval originalInterval := getHealthCheckInterval()
originalTimeout := heartbeatTimeout originalTimeout := getHeartBeatTimeout()
healthCheckInterval = 1 * time.Second setHealthCheckInterval(1 * time.Second)
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond setHeartBeatTimeout(getHealthCheckInterval() + 500*time.Millisecond)
defer func() { defer func() {
healthCheckInterval = originalInterval setHealthCheckInterval(originalInterval)
heartbeatTimeout = originalTimeout setHeartBeatTimeout(originalTimeout)
}() }()
//nolint:tenv //nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
@@ -73,7 +73,7 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
receiver := NewReceiver(log.WithField("test_name", tc.name)) receiver := NewReceiver(log.WithField("test_name", tc.name))
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval testTimeout := getHeartBeatTimeout()*time.Duration(tc.threshold) + getHealthCheckInterval()
if tc.resetCounterOnce { if tc.resetCounterOnce {
receiver.Heartbeat() receiver.Heartbeat()

View File

@@ -19,6 +19,30 @@ var (
healthCheckTimeout = 20 * time.Second healthCheckTimeout = 20 * time.Second
) )
func getHealthCheckInterval() time.Duration {
mux.Lock()
defer mux.Unlock()
return healthCheckInterval
}
func setHealthCheckInterval(interval time.Duration) {
mux.Lock()
defer mux.Unlock()
healthCheckInterval = interval
}
func getHealthCheckTimeout() time.Duration {
mux.Lock()
defer mux.Unlock()
return healthCheckTimeout
}
func setHealthCheckTimeout(timeout time.Duration) {
mux.Lock()
defer mux.Unlock()
healthCheckTimeout = timeout
}
// Sender is a healthcheck sender // Sender is a healthcheck sender
// It will send healthcheck signal to the receiver // It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
@@ -57,7 +81,7 @@ func (hc *Sender) OnHCResponse() {
} }
func (hc *Sender) StartHealthCheck(ctx context.Context) { func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval) ticker := time.NewTicker(getHealthCheckInterval())
defer ticker.Stop() defer ticker.Stop()
timeoutTicker := time.NewTicker(hc.getTimeoutTime()) timeoutTicker := time.NewTicker(hc.getTimeoutTime())
@@ -94,7 +118,7 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
} }
func (hc *Sender) getTimeoutTime() time.Duration { func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout return getHealthCheckInterval() + getHealthCheckTimeout()
} }
func getAttemptThresholdFromEnv() int { func getAttemptThresholdFromEnv() int {

View File

@@ -12,8 +12,8 @@ import (
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// override the health check interval to speed up the test // override the health check interval to speed up the test
healthCheckInterval = 2 * time.Second setHealthCheckInterval(2 * time.Second)
healthCheckTimeout = 100 * time.Millisecond setHealthCheckTimeout(100 * time.Millisecond)
code := m.Run() code := m.Run()
os.Exit(code) os.Exit(code)
} }
@@ -32,7 +32,7 @@ func TestNewHealthPeriod(t *testing.T) {
hc.OnHCResponse() hc.OnHCResponse()
case <-hc.Timeout: case <-hc.Timeout:
t.Fatalf("health check is timed out") t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond): case <-time.After(getHealthCheckInterval() + 100*time.Millisecond):
t.Fatalf("health check not received") t.Fatalf("health check not received")
} }
} }
@@ -46,7 +46,7 @@ func TestNewHealthFailed(t *testing.T) {
select { select {
case <-hc.Timeout: case <-hc.Timeout:
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): case <-time.After(getHealthCheckInterval() + getHealthCheckTimeout() + 100*time.Millisecond):
t.Fatalf("health check is not timed out") t.Fatalf("health check is not timed out")
} }
} }
@@ -89,7 +89,7 @@ func TestTimeoutReset(t *testing.T) {
hc.OnHCResponse() hc.OnHCResponse()
case <-hc.Timeout: case <-hc.Timeout:
t.Fatalf("health check is timed out") t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond): case <-time.After(getHealthCheckInterval() + 100*time.Millisecond):
t.Fatalf("health check not received") t.Fatalf("health check not received")
} }
} }
@@ -118,13 +118,13 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases { for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval originalInterval := getHealthCheckInterval()
originalTimeout := healthCheckTimeout originalTimeout := getHealthCheckTimeout()
healthCheckInterval = 1 * time.Second setHealthCheckInterval(1 * time.Second)
healthCheckTimeout = 500 * time.Millisecond setHealthCheckTimeout(500 * time.Millisecond)
defer func() { defer func() {
healthCheckInterval = originalInterval setHealthCheckInterval(originalInterval)
healthCheckTimeout = originalTimeout setHealthCheckTimeout(originalTimeout)
}() }()
//nolint:tenv //nolint:tenv
@@ -155,7 +155,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
} }
}() }()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + getHealthCheckInterval()
select { select {
case <-sender.Timeout: case <-sender.Timeout: