mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
update some tests
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
//go:build !race
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -220,10 +219,8 @@ func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []by
|
||||
})
|
||||
}
|
||||
|
||||
func (t *UDPTracker) getConnections() map[ConnKey]*UDPConnTrack {
|
||||
func (t *UDPTracker) getConnectionsLen() int {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
copyConn := make(map[ConnKey]*UDPConnTrack, len(t.connections))
|
||||
maps.Copy(copyConn, t.connections)
|
||||
return copyConn
|
||||
return len(t.connections)
|
||||
}
|
||||
|
||||
@@ -202,13 +202,13 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
assert.Len(t, tracker.getConnections(), 2)
|
||||
assert.Equal(t, 2, tracker.getConnectionsLen())
|
||||
|
||||
// Wait for connection timeout and cleanup interval
|
||||
time.Sleep(timeout + 2*cleanupInterval)
|
||||
|
||||
tracker.mutex.RLock()
|
||||
connCount := len(tracker.getConnections())
|
||||
connCount := tracker.getConnectionsLen()
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
// Verify connections were cleaned up
|
||||
|
||||
@@ -50,7 +50,7 @@ type upstreamResolverBase struct {
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []string
|
||||
domain string
|
||||
disabled bool
|
||||
disabled atomic.Bool
|
||||
failsCount atomic.Int32
|
||||
successCount atomic.Int32
|
||||
failsTillDeact int32
|
||||
@@ -176,7 +176,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
||||
if u.failsCount.Load() < u.failsTillDeact || u.disabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
u.failsCount.Store(0)
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.disabled = false
|
||||
u.disabled.Store(false)
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
@@ -320,14 +320,14 @@ func isTimeout(err error) bool {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) disable(err error) {
|
||||
if u.disabled {
|
||||
if u.disabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
u.disabled.Store(true)
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
if !resolver.disabled.Load() {
|
||||
t.Errorf("resolver should be Disabled")
|
||||
return
|
||||
}
|
||||
@@ -180,7 +180,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
if resolver.disabled.Load() {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user