mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-03 14:39:55 +00:00
Compare commits
14 Commits
nb-interfa
...
add/race-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2180e4bb2 | ||
|
|
fc735d1337 | ||
|
|
33f4c3bd3f | ||
|
|
b6cef0cd26 | ||
|
|
cc78a3c65f | ||
|
|
cc1c77f6dc | ||
|
|
d7d57a4ec4 | ||
|
|
09d0fea5ca | ||
|
|
4e737a482b | ||
|
|
62deb64f5f | ||
|
|
bdb38dfa57 | ||
|
|
84988b4d53 | ||
|
|
7b4cc63054 | ||
|
|
e66412da1b |
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -42,5 +42,5 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -race -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||
|
||||
|
||||
8
.github/workflows/golang-test-linux.yml
vendored
8
.github/workflows/golang-test-linux.yml
vendored
@@ -103,7 +103,11 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
include:
|
||||
- arch: "386"
|
||||
raceFlag: ""
|
||||
- arch: "amd64"
|
||||
raceFlag: "-race"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Install Go
|
||||
@@ -144,7 +148,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' ${{ matrix.raceFlag }} -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
|
||||
14
.github/workflows/golang-test-windows.yml
vendored
14
.github/workflows/golang-test-windows.yml
vendored
@@ -63,10 +63,16 @@ jobs:
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/client/iface' } )" >> $env:GITHUB_ENV
|
||||
|
||||
- name: test
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output
|
||||
- name: test without iface
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -race -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||
- name: test output without iface
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
# todo: remove this once iface tests are stable with race flag
|
||||
- name: test iface
|
||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ./client/iface/... > test-out.txt 2>&1"
|
||||
- name: test output iface
|
||||
if: ${{ always() }}
|
||||
run: Get-Content test-out.txt
|
||||
|
||||
@@ -307,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||
} else {
|
||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
|
||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||
)
|
||||
}
|
||||
return statusOutputString
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !race
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
|
||||
@@ -26,7 +26,6 @@ var (
|
||||
statusFilter string
|
||||
ipsFilterMap map[string]struct{}
|
||||
prefixNamesFilterMap map[string]struct{}
|
||||
connectionTypeFilter string
|
||||
)
|
||||
|
||||
var statusCmd = &cobra.Command{
|
||||
@@ -46,7 +45,6 @@ func init() {
|
||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||
}
|
||||
|
||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -91,7 +89,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
|
||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||
var statusOutputString string
|
||||
switch {
|
||||
case detailFlag:
|
||||
@@ -158,15 +156,6 @@ func parseFilters() error {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
|
||||
switch strings.ToLower(connectionTypeFilter) {
|
||||
case "", "p2p", "relayed":
|
||||
if strings.ToLower(connectionTypeFilter) != "" {
|
||||
enableDetailFlagWhenFilterFlag()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -218,3 +218,9 @@ func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []by
|
||||
TxBytes: conn.BytesTx.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
func (t *UDPTracker) getConnectionsLen() int {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
return len(t.connections)
|
||||
}
|
||||
|
||||
@@ -202,13 +202,13 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify initial connections
|
||||
assert.Len(t, tracker.connections, 2)
|
||||
assert.Equal(t, 2, tracker.getConnectionsLen())
|
||||
|
||||
// Wait for connection timeout and cleanup interval
|
||||
time.Sleep(timeout + 2*cleanupInterval)
|
||||
|
||||
tracker.mutex.RLock()
|
||||
connCount := len(tracker.connections)
|
||||
connCount := tracker.getConnectionsLen()
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
// Verify connections were cleaned up
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||
func init() {
|
||||
listener := nbnet.NewListener()
|
||||
if listener.ListenConfig.Control != nil {
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
|
||||
}
|
||||
}
|
||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// ControlFns is not thread safe and should only be modified during init.
|
||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||
}
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
@@ -154,7 +153,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
||||
|
||||
s.udpMux = NewUniversalUDPMuxDefault(
|
||||
UniversalUDPMuxParams{
|
||||
UDPConn: nbnet.WrapUDPConn(conn),
|
||||
UDPConn: conn,
|
||||
Net: s.transportNet,
|
||||
FilterFn: s.filterFn,
|
||||
WGAddress: s.address,
|
||||
|
||||
@@ -296,20 +296,14 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
||||
return
|
||||
}
|
||||
|
||||
var allAddresses []string
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
for _, c := range removedConns {
|
||||
addresses := c.getAddresses()
|
||||
allAddresses = append(allAddresses, addresses...)
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
for _, addr := range allAddresses {
|
||||
delete(m.addressMap, addr)
|
||||
}
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
for _, addr := range allAddresses {
|
||||
m.notifyAddressRemoval(addr)
|
||||
for _, addr := range addresses {
|
||||
delete(m.addressMap, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,13 +351,14 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
||||
}
|
||||
|
||||
m.addressMapMu.Lock()
|
||||
defer m.addressMapMu.Unlock()
|
||||
|
||||
existing, ok := m.addressMap[addr]
|
||||
if !ok {
|
||||
existing = []*udpMuxedConn{}
|
||||
}
|
||||
existing = append(existing, conn)
|
||||
m.addressMap[addr] = existing
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||
}
|
||||
@@ -391,12 +386,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
|
||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||
// We will then forward STUN packets to each of these connections.
|
||||
m.addressMapMu.RLock()
|
||||
m.addressMapMu.Lock()
|
||||
var destinationConnList []*udpMuxedConn
|
||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||
destinationConnList = append(destinationConnList, storedConns...)
|
||||
}
|
||||
m.addressMapMu.RUnlock()
|
||||
m.addressMapMu.Unlock()
|
||||
|
||||
var isIPv6 bool
|
||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
wrapped, ok := m.params.UDPConn.(*UDPConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
nbnetConn.RemoveAddress(addr)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build ios
|
||||
|
||||
package bind
|
||||
|
||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
|
||||
// wrap UDP connection, process server reflexive messages
|
||||
// before they are passed to the UDPMux connection handler (connWorker)
|
||||
m.params.UDPConn = &UDPConn{
|
||||
m.params.UDPConn = &udpConn{
|
||||
PacketConn: params.UDPConn,
|
||||
mux: m,
|
||||
logger: params.Logger,
|
||||
@@ -70,6 +70,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
||||
address: params.WGAddress,
|
||||
}
|
||||
|
||||
// embed UDPMux
|
||||
udpMuxParams := UDPMuxParams{
|
||||
Logger: params.Logger,
|
||||
UDPConn: m.params.UDPConn,
|
||||
@@ -113,8 +114,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type UDPConn struct {
|
||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||
type udpConn struct {
|
||||
net.PacketConn
|
||||
mux *UniversalUDPMuxDefault
|
||||
logger logging.LeveledLogger
|
||||
@@ -124,12 +125,7 @@ type UDPConn struct {
|
||||
address wgaddr.Address
|
||||
}
|
||||
|
||||
// GetPacketConn returns the underlying PacketConn
|
||||
func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||
return u.PacketConn
|
||||
}
|
||||
|
||||
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
if u.filterFn == nil {
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
@@ -141,21 +137,21 @@ func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||
return u.handleUncachedAddress(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||
if isRouted {
|
||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||
if err := u.performFilterCheck(addr); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return u.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
||||
host, err := getHostFromAddr(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||
|
||||
@@ -3,4 +3,4 @@
|
||||
package configurer
|
||||
|
||||
// WgInterfaceDefault is a default interface name of Netbird
|
||||
const WgInterfaceDefault = "nb0"
|
||||
const WgInterfaceDefault = "wt0"
|
||||
|
||||
@@ -39,7 +39,7 @@ const (
|
||||
)
|
||||
|
||||
var defaultInterfaceBlacklist = []string{
|
||||
iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ type upstreamResolverBase struct {
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []string
|
||||
domain string
|
||||
disabled bool
|
||||
disabled atomic.Bool
|
||||
failsCount atomic.Int32
|
||||
successCount atomic.Int32
|
||||
failsTillDeact int32
|
||||
@@ -176,7 +176,7 @@ func (u *upstreamResolverBase) checkUpstreamFails(err error) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
|
||||
if u.failsCount.Load() < u.failsTillDeact || u.disabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -305,7 +305,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
u.failsCount.Store(0)
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.disabled = false
|
||||
u.disabled.Store(false)
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
@@ -320,14 +320,14 @@ func isTimeout(err error) bool {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) disable(err error) {
|
||||
if u.disabled {
|
||||
if u.disabled.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
u.disabled.Store(true)
|
||||
go u.waitUntilResponse()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -135,32 +136,41 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error { return nil },
|
||||
}
|
||||
|
||||
lmux := sync.Mutex{}
|
||||
failed := false
|
||||
resolver.deactivate = func(error) {
|
||||
lmux.Lock()
|
||||
failed = true
|
||||
lmux.Unlock()
|
||||
}
|
||||
|
||||
reactivated := false
|
||||
resolver.reactivate = func() {
|
||||
lmux.Lock()
|
||||
reactivated = true
|
||||
lmux.Unlock()
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA))
|
||||
|
||||
if !failed {
|
||||
lmux.Lock()
|
||||
failedCheck := failed
|
||||
lmux.Unlock()
|
||||
if !failedCheck {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
if !resolver.disabled.Load() {
|
||||
t.Errorf("resolver should be Disabled")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
if !reactivated {
|
||||
lmux.Lock()
|
||||
checkReactivated := reactivated
|
||||
lmux.Unlock()
|
||||
if !checkReactivated {
|
||||
t.Errorf("expected that resolving was reactivated")
|
||||
return
|
||||
}
|
||||
@@ -170,7 +180,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
if resolver.disabled.Load() {
|
||||
t.Errorf("should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +61,7 @@ import (
|
||||
signal "github.com/netbirdio/netbird/signal/client"
|
||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||
@@ -137,6 +138,9 @@ type Engine struct {
|
||||
|
||||
connMgr *ConnMgr
|
||||
|
||||
beforePeerHook nbnet.AddHookFunc
|
||||
afterPeerHook nbnet.RemoveHookFunc
|
||||
|
||||
// rpManager is a Rosenpass manager
|
||||
rpManager *rosenpass.Manager
|
||||
|
||||
@@ -405,8 +409,12 @@ func (e *Engine) Start() error {
|
||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||
})
|
||||
if err := e.routeManager.Init(); err != nil {
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
} else {
|
||||
e.beforePeerHook = beforePeerHook
|
||||
e.afterPeerHook = afterPeerHook
|
||||
}
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
@@ -828,7 +836,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
e.syncMsgMux.Lock()
|
||||
sshServer := e.sshServer
|
||||
e.syncMsgMux.Unlock()
|
||||
err = sshServer.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
@@ -843,6 +854,8 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
@@ -1253,6 +1266,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||
}
|
||||
|
||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
StatusRecorder: engine.statusRecorder,
|
||||
RelayManager: relayMgr,
|
||||
})
|
||||
err = engine.routeManager.Init()
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
@@ -1393,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
if runtime.GOOS == "darwin" {
|
||||
ifaceName = fmt.Sprintf("utun1%d", i)
|
||||
} else {
|
||||
ifaceName = fmt.Sprintf("nb%d", i)
|
||||
ifaceName = fmt.Sprintf("wt%d", i)
|
||||
}
|
||||
|
||||
wgPort := 33100 + i
|
||||
|
||||
@@ -102,3 +102,11 @@ func (m *Manager) notify(peerConnID peerid.ConnID) {
|
||||
case m.OnActivityChan <- peerConnID:
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
listener, ok := m.peers[peerConnID]
|
||||
return listener, ok
|
||||
}
|
||||
|
||||
@@ -50,8 +50,11 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
if err := mgr.MonitorPeerActivity(peerCfg1); err != nil {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
listener, ok := mgr.getPeerListener(peerCfg1.PeerConnID)
|
||||
if !ok {
|
||||
t.Fatalf("failed to get peer listener: %s", peerCfg1.PublicKey)
|
||||
}
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
@@ -83,7 +86,12 @@ func TestManager_RemovePeerActivity(t *testing.T) {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
||||
peer1Listener, ok := mgr.getPeerListener(peerCfg1.PeerConnID)
|
||||
if !ok {
|
||||
t.Fatalf("failed to get peer listener: %s", peerCfg1.PublicKey)
|
||||
}
|
||||
|
||||
addr := peer1Listener.conn.LocalAddr().String()
|
||||
|
||||
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||
|
||||
@@ -128,11 +136,20 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
peer1Listener, ok := mgr.getPeerListener(peerCfg1.PeerConnID)
|
||||
if !ok {
|
||||
t.Fatalf("failed to get peer listener: %s", peerCfg1.PublicKey)
|
||||
}
|
||||
if err := trigger(peer1Listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
if err := trigger(mgr.peers[peerCfg2.PeerConnID].conn.LocalAddr().String()); err != nil {
|
||||
peer2Listener, ok := mgr.getPeerListener(peerCfg2.PeerConnID)
|
||||
if !ok {
|
||||
t.Fatalf("failed to get peer listener: %s", peerCfg2.PublicKey)
|
||||
}
|
||||
|
||||
if err := trigger(peer2Listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "nb0"
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
|
||||
@@ -22,6 +22,19 @@ const (
|
||||
)
|
||||
|
||||
var checkChangeFn = checkChange
|
||||
var mux sync.Mutex
|
||||
|
||||
func getCheckChangeFn() func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
return checkChangeFn
|
||||
}
|
||||
|
||||
func setCheckChangeFn(fn func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
checkChangeFn = fn
|
||||
}
|
||||
|
||||
// NetworkMonitor watches for changes in network configuration.
|
||||
type NetworkMonitor struct {
|
||||
@@ -120,7 +133,8 @@ func (nw *NetworkMonitor) Stop() {
|
||||
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
|
||||
defer close(event)
|
||||
for {
|
||||
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
|
||||
checkFn := getCheckChangeFn()
|
||||
if err := checkFn(ctx, nexthop4, nexthop6); err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Network monitor: failed to check for changes: %v", err)
|
||||
}
|
||||
|
||||
@@ -25,10 +25,10 @@ func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 sy
|
||||
}
|
||||
|
||||
func TestNetworkMonitor_Close(t *testing.T) {
|
||||
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
setCheckChangeFn(func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
nw := New()
|
||||
|
||||
var resErr error
|
||||
@@ -48,7 +48,7 @@ func TestNetworkMonitor_Close(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNetworkMonitor_Event(t *testing.T) {
|
||||
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
setCheckChangeFn(func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||
defer cancel()
|
||||
select {
|
||||
@@ -57,7 +57,7 @@ func TestNetworkMonitor_Event(t *testing.T) {
|
||||
case <-timeout.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
nw := New()
|
||||
defer nw.Stop()
|
||||
|
||||
@@ -77,7 +77,7 @@ func TestNetworkMonitor_Event(t *testing.T) {
|
||||
func TestNetworkMonitor_MultiEvent(t *testing.T) {
|
||||
eventsRepeated := 3
|
||||
me := &MocMultiEvent{counter: eventsRepeated}
|
||||
checkChangeFn = me.checkChange
|
||||
setCheckChangeFn(me.checkChange)
|
||||
|
||||
nw := New()
|
||||
defer nw.Stop()
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
@@ -98,13 +99,17 @@ type Conn struct {
|
||||
|
||||
statusRelay *worker.AtomicWorkerStatus
|
||||
statusICE *worker.AtomicWorkerStatus
|
||||
currentConnPriority conntype.ConnPriority
|
||||
currentConnPriority conntype.ConnPriorityStore
|
||||
opened bool // this flag is used to prevent close in case of not opened connection
|
||||
|
||||
workerICE *WorkerICE
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
connIDRelay nbnet.ConnectionID
|
||||
connIDICE nbnet.ConnectionID
|
||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||
rosenpassRemoteKey []byte
|
||||
|
||||
@@ -262,6 +267,8 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
conn.freeUpConnID()
|
||||
|
||||
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
||||
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
@@ -276,7 +283,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
// doesn't block, discards the message if connection wasn't ready
|
||||
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||
conn.dumpState.RemoteAnswer()
|
||||
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
|
||||
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority.Get(), conn.statusICE, conn.statusRelay)
|
||||
return conn.handshaker.OnRemoteAnswer(answer)
|
||||
}
|
||||
|
||||
@@ -286,6 +293,13 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
|
||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
||||
}
|
||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
|
||||
conn.onConnected = handler
|
||||
@@ -339,8 +353,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
|
||||
// this never should happen, because Relay is the lower priority and ICE always close the deprecated connection before upgrade
|
||||
// todo consider to remove this check
|
||||
if conn.currentConnPriority > priority {
|
||||
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority, priority)
|
||||
if conn.currentConnPriority.Get() > priority {
|
||||
conn.Log.Infof("current connection priority (%s) is higher than the new one (%s), do not upgrade connection", conn.currentConnPriority.Get(), priority)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
return
|
||||
@@ -373,6 +387,10 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
ep = directEp
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
@@ -390,7 +408,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.currentConnPriority.Set(priority)
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||
@@ -427,10 +445,10 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.currentConnPriority.Set(conntype.Relay)
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.currentConnPriority.Set(conntype.None)
|
||||
}
|
||||
|
||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||
@@ -478,13 +496,17 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||
|
||||
if conn.isICEActive() {
|
||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||
conn.Log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.Get().String())
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
@@ -502,7 +524,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.currentConnPriority.Set(conntype.Relay)
|
||||
conn.statusRelay.SetConnected()
|
||||
conn.setRelayedProxy(wgProxy)
|
||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||
@@ -520,9 +542,9 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
|
||||
conn.Log.Debugf("relay connection is disconnected")
|
||||
|
||||
if conn.currentConnPriority == conntype.Relay {
|
||||
if conn.currentConnPriority.Get() == conntype.Relay {
|
||||
conn.Log.Debugf("clean up WireGuard config")
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.currentConnPriority.Set(conntype.None)
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
@@ -604,7 +626,7 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
|
||||
func (conn *Conn) setStatusToDisconnected() {
|
||||
conn.statusRelay.SetDisconnected()
|
||||
conn.statusICE.SetDisconnected()
|
||||
conn.currentConnPriority = conntype.None
|
||||
conn.currentConnPriority.Set(conntype.None)
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -647,7 +669,7 @@ func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
switch conn.currentConnPriority {
|
||||
switch conn.currentConnPriority.Get() {
|
||||
case conntype.Relay, conntype.ICETurn:
|
||||
return true
|
||||
default:
|
||||
@@ -685,6 +707,36 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
||||
conn.connIDICE = nbnet.GenerateConnID()
|
||||
for _, hook := range conn.beforeAddPeerHooks {
|
||||
if err := hook(conn.connIDICE, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *Conn) freeUpConnID() {
|
||||
if conn.connIDRelay != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDRelay); err != nil {
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDRelay = ""
|
||||
}
|
||||
|
||||
if conn.connIDICE != "" {
|
||||
for _, hook := range conn.afterRemovePeerHooks {
|
||||
if err := hook(conn.connIDICE); err != nil {
|
||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
||||
}
|
||||
}
|
||||
conn.connIDICE = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
@@ -701,11 +753,11 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
}
|
||||
|
||||
func (conn *Conn) isReadyToUpgrade() bool {
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||
return conn.wgProxyRelay != nil && conn.currentConnPriority.Get() != conntype.Relay
|
||||
}
|
||||
|
||||
func (conn *Conn) isICEActive() bool {
|
||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||
return (conn.currentConnPriority.Get() == conntype.ICEP2P || conn.currentConnPriority.Get() == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
|
||||
@@ -2,6 +2,7 @@ package conntype
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -11,7 +12,7 @@ const (
|
||||
ICEP2P ConnPriority = 3
|
||||
)
|
||||
|
||||
type ConnPriority int
|
||||
type ConnPriority int32
|
||||
|
||||
func (cp ConnPriority) String() string {
|
||||
switch cp {
|
||||
@@ -27,3 +28,15 @@ func (cp ConnPriority) String() string {
|
||||
return fmt.Sprintf("ConnPriority(%d)", cp)
|
||||
}
|
||||
}
|
||||
|
||||
type ConnPriorityStore struct {
|
||||
store atomic.Int32
|
||||
}
|
||||
|
||||
func (cps *ConnPriorityStore) Get() ConnPriority {
|
||||
return ConnPriority(cps.store.Load())
|
||||
}
|
||||
|
||||
func (cps *ConnPriorityStore) Set(cp ConnPriority) {
|
||||
cps.store.Store(int32(cp))
|
||||
}
|
||||
|
||||
@@ -9,30 +9,54 @@ type mocListener struct {
|
||||
lastState int
|
||||
wg sync.WaitGroup
|
||||
peers int
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func (l *mocListener) OnConnected() {
|
||||
l.mux.Lock()
|
||||
defer l.mux.Unlock()
|
||||
l.lastState = stateConnected
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnDisconnected() {
|
||||
l.mux.Lock()
|
||||
defer l.mux.Unlock()
|
||||
l.lastState = stateDisconnected
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnConnecting() {
|
||||
l.mux.Lock()
|
||||
defer l.mux.Unlock()
|
||||
l.lastState = stateConnecting
|
||||
l.wg.Done()
|
||||
}
|
||||
func (l *mocListener) OnDisconnecting() {
|
||||
l.mux.Lock()
|
||||
defer l.mux.Unlock()
|
||||
l.lastState = stateDisconnecting
|
||||
l.wg.Done()
|
||||
|
||||
}
|
||||
|
||||
func (l *mocListener) getLastState() int {
|
||||
l.mux.Lock()
|
||||
defer l.mux.Unlock()
|
||||
return l.lastState
|
||||
}
|
||||
|
||||
func (l *mocListener) OnAddressChanged(host, addr string) {
|
||||
|
||||
}
|
||||
func (l *mocListener) OnPeersListChanged(size int) {
|
||||
l.mux.Lock()
|
||||
l.peers = size
|
||||
l.mux.Unlock()
|
||||
}
|
||||
|
||||
func (l *mocListener) getPeers() int {
|
||||
l.mux.Lock()
|
||||
defer l.mux.Unlock()
|
||||
return l.peers
|
||||
}
|
||||
|
||||
func (l *mocListener) setWaiter() {
|
||||
@@ -77,7 +101,7 @@ func Test_notifier_SetListener(t *testing.T) {
|
||||
n.lastNotification = stateConnecting
|
||||
n.setListener(listener)
|
||||
listener.wait()
|
||||
if listener.lastState != n.lastNotification {
|
||||
if listener.getLastState() != n.lastNotification {
|
||||
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||
}
|
||||
}
|
||||
@@ -91,7 +115,7 @@ func Test_notifier_RemoveListener(t *testing.T) {
|
||||
n.removeListener()
|
||||
n.peerListChanged(1)
|
||||
|
||||
if listener.peers != 0 {
|
||||
if listener.getPeers() != 0 {
|
||||
t.Errorf("invalid state: %d", listener.peers)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
||||
}
|
||||
|
||||
params := common.HandlerParams{
|
||||
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||
}
|
||||
// create new clientNetwork
|
||||
client := &Watcher{
|
||||
|
||||
@@ -44,7 +44,7 @@ import (
|
||||
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() error
|
||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||
TriggerSelection(route.HAMap)
|
||||
@@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||
}
|
||||
|
||||
// Init sets up the routing
|
||||
func (m *DefaultManager) Init() error {
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
m.routeSelector = m.initSelector()
|
||||
|
||||
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||
@@ -219,12 +219,13 @@ func (m *DefaultManager) Init() error {
|
||||
|
||||
ips := resolveURLsToIPs(initialAddresses)
|
||||
|
||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||
return fmt.Errorf("setup routing: %w", err)
|
||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Routing setup complete")
|
||||
return nil
|
||||
return beforePeerHook, afterPeerHook, nil
|
||||
}
|
||||
|
||||
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
|
||||
@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
StatusRecorder: statusRecorder,
|
||||
})
|
||||
|
||||
err = routeManager.Init()
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
require.NoError(t, err, "should init route manager")
|
||||
defer routeManager.Stop(nil)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
@@ -22,8 +23,8 @@ type MockManager struct {
|
||||
StopFunc func(manager *statemanager.Manager)
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() error {
|
||||
return nil
|
||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||
|
||||
@@ -33,4 +33,4 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
|
||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||
return []string{}
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||
@@ -57,10 +56,6 @@ type SysOps struct {
|
||||
// seq is an atomic counter for generating unique sequence numbers for route messages
|
||||
//nolint:unused // only used on BSD systems
|
||||
seq atomic.Uint32
|
||||
|
||||
localSubnetsCache []*net.IPNet
|
||||
localSubnetsCacheMu sync.RWMutex
|
||||
localSubnetsCacheTime time.Time
|
||||
}
|
||||
|
||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
return nil
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/libp2p/go-netroute"
|
||||
@@ -25,8 +24,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const localSubnetsCacheTTL = 15 * time.Minute
|
||||
|
||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
@@ -34,7 +31,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||
@@ -78,10 +75,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
||||
|
||||
r.refCounter = refCounter
|
||||
|
||||
if err := r.setupHooks(initAddresses, stateManager); err != nil {
|
||||
return fmt.Errorf("setup hooks: %w", err)
|
||||
}
|
||||
return nil
|
||||
return r.setupHooks(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
// updateState updates state on every change so it will be persisted regularly
|
||||
@@ -134,14 +128,18 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
|
||||
exitNextHop := nexthop
|
||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
||||
exitNextHop := Nexthop{
|
||||
IP: nexthop.IP,
|
||||
Intf: nexthop.Intf,
|
||||
}
|
||||
|
||||
vpnAddr := vpnIntf.Address().IP
|
||||
|
||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
||||
|
||||
exitNextHop = initialNextHop
|
||||
}
|
||||
|
||||
@@ -154,37 +152,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
||||
}
|
||||
|
||||
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
||||
r.localSubnetsCacheMu.RLock()
|
||||
cacheAge := time.Since(r.localSubnetsCacheTime)
|
||||
subnets := r.localSubnetsCache
|
||||
r.localSubnetsCacheMu.RUnlock()
|
||||
|
||||
if cacheAge > localSubnetsCacheTTL || subnets == nil {
|
||||
r.localSubnetsCacheMu.Lock()
|
||||
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
|
||||
r.refreshLocalSubnetsCache()
|
||||
}
|
||||
subnets = r.localSubnetsCache
|
||||
r.localSubnetsCacheMu.Unlock()
|
||||
}
|
||||
|
||||
for _, subnet := range subnets {
|
||||
if subnet.Contains(prefix.Addr().AsSlice()) {
|
||||
return true, subnet
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) refreshLocalSubnetsCache() {
|
||||
localInterfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get local interfaces: %v", err)
|
||||
return
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var newSubnets []*net.IPNet
|
||||
for _, intf := range localInterfaces {
|
||||
addrs, err := intf.Addrs()
|
||||
if err != nil {
|
||||
@@ -198,12 +171,14 @@ func (r *SysOps) refreshLocalSubnetsCache() {
|
||||
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
||||
continue
|
||||
}
|
||||
newSubnets = append(newSubnets, ipnet)
|
||||
|
||||
if ipnet.Contains(prefix.Addr().AsSlice()) {
|
||||
return true, ipnet
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.localSubnetsCache = newSubnets
|
||||
r.localSubnetsCacheTime = time.Now()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||
@@ -289,7 +264,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
||||
return r.removeFromRouteTable(prefix, nextHop)
|
||||
}
|
||||
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||
prefix, err := util.GetPrefixFromIP(ip)
|
||||
if err != nil {
|
||||
@@ -314,11 +289,9 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, ip := range initAddresses {
|
||||
if err := beforeHook("init", ip); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||
log.Errorf("Failed to add route reference: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -327,11 +300,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var result *multierror.Error
|
||||
for _, ip := range resolvedIPs {
|
||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return nberrors.FormatErrorOrNil(result)
|
||||
})
|
||||
|
||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||
@@ -346,16 +319,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
||||
return afterHook(connID)
|
||||
})
|
||||
|
||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||
return fmt.Errorf("remove route reference: %w", err)
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
return nil
|
||||
})
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
return beforeHook, afterHook, nil
|
||||
}
|
||||
|
||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||
|
||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
|
||||
})
|
||||
|
||||
r := NewSysOps(wgInterface, nil)
|
||||
err := r.SetupRouting(nil, nil)
|
||||
_, _, err := r.SetupRouting(nil, nil)
|
||||
require.NoError(t, err, "setupRouting should not return err")
|
||||
t.Cleanup(func() {
|
||||
assert.NoError(t, r.CleanupRouting(nil))
|
||||
|
||||
@@ -10,13 +10,14 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.prefixes = make(map[netip.Prefix]struct{})
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||
|
||||
@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
|
||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||
// This table is where a default route or other specific routes received from the management server are configured,
|
||||
// enabling VPN connectivity.
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
log.Infof("Using legacy routing setup")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
rules := getSetupRules()
|
||||
for _, rule := range rules {
|
||||
if err := addRule(rule); err != nil {
|
||||
return fmt.Errorf("%s: %w", rule.description, err)
|
||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
||||
}
|
||||
originalSysctl = originalValues
|
||||
|
||||
return nil
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||
|
||||
@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
|
||||
IP: wgNetwork.Addr(),
|
||||
Network: wgNetwork,
|
||||
},
|
||||
name: "nb0",
|
||||
name: "wt0",
|
||||
}
|
||||
|
||||
sysOps := &SysOps{
|
||||
|
||||
@@ -18,9 +18,10 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
const InfiniteLifetime = 0xffffffff
|
||||
@@ -136,7 +137,7 @@ const (
|
||||
RouteDeleted
|
||||
)
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
|
||||
@@ -1330,13 +1330,6 @@ func (x *PeerState) GetRelayAddress() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PeerState) GetConnectionType() string {
|
||||
if x.Relayed {
|
||||
return "Relayed"
|
||||
}
|
||||
return "P2P"
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
type LocalPeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
|
||||
@@ -100,7 +100,7 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
}
|
||||
|
||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) OutputOverview {
|
||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) OutputOverview {
|
||||
pbFullStatus := resp.GetFullStatus()
|
||||
|
||||
managementState := pbFullStatus.GetManagementState()
|
||||
@@ -118,7 +118,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
||||
}
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
Peers: peersOverview,
|
||||
@@ -193,7 +193,6 @@ func mapPeers(
|
||||
prefixNamesFilter []string,
|
||||
prefixNamesFilterMap map[string]struct{},
|
||||
ipsFilter map[string]struct{},
|
||||
connectionTypeFilter string,
|
||||
) PeersStateOutput {
|
||||
var peersStateDetail []PeerStateDetailOutput
|
||||
peersConnected := 0
|
||||
@@ -209,7 +208,7 @@ func mapPeers(
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) {
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter) {
|
||||
continue
|
||||
}
|
||||
if isPeerConnected {
|
||||
@@ -219,7 +218,10 @@ func mapPeers(
|
||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
||||
connType = pbPeerState.GetConnectionType()
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
||||
transferReceived = pbPeerState.GetBytesRx()
|
||||
@@ -540,11 +542,10 @@ func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bo
|
||||
return peersString
|
||||
}
|
||||
|
||||
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string) bool {
|
||||
func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}) bool {
|
||||
statusEval := false
|
||||
ipEval := false
|
||||
nameEval := true
|
||||
connectionTypeEval := false
|
||||
|
||||
if statusFilter != "" {
|
||||
if !strings.EqualFold(peerStatus, statusFilter) {
|
||||
@@ -569,11 +570,8 @@ func skipDetailByFilters(peerState *proto.PeerState, peerStatus string, statusFi
|
||||
} else {
|
||||
nameEval = false
|
||||
}
|
||||
if connectionTypeFilter != "" && !strings.EqualFold(peerState.GetConnectionType(), connectionTypeFilter) {
|
||||
connectionTypeEval = true
|
||||
}
|
||||
|
||||
return statusEval || ipEval || nameEval || connectionTypeEval
|
||||
return statusEval || ipEval || nameEval
|
||||
}
|
||||
|
||||
func toIEC(b int64) string {
|
||||
|
||||
@@ -234,7 +234,7 @@ var overview = OutputOverview{
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "")
|
||||
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil)
|
||||
|
||||
assert.Equal(t, overview, convertedResult)
|
||||
}
|
||||
|
||||
@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var postUpStatusOutput string
|
||||
if postUpStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil)
|
||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
|
||||
|
||||
var preDownStatusOutput string
|
||||
if preDownStatus != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil)
|
||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
||||
|
||||
var statusOutput string
|
||||
if statusResp != nil {
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "")
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil)
|
||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||
}
|
||||
|
||||
|
||||
2
go.mod
2
go.mod
@@ -257,6 +257,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
|
||||
|
||||
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
|
||||
|
||||
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e
|
||||
replace github.com/pion/ice/v3 => github.com/netbirdio/ice/v3 v3.0.0-20250718163601-725c8ac53a31
|
||||
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
4
go.sum
4
go.sum
@@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
|
||||
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20250718163601-725c8ac53a31 h1:lr/CnQ9NnlHr4yjDaCqy3V1FW+y9DDpzqxu1+YXzXtc=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20250718163601-725c8ac53a31/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5 h1:Zfn8d83OVyELCdxgprcyXR3D8uqoxHtXE9PUxVXDx/w=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250718071730-f4d133556ff5/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
|
||||
@@ -424,10 +424,9 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
||||
}
|
||||
if group, ok := groupsMap[gid]; ok {
|
||||
minimum := api.GroupMinimum{
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
ResourcesCount: len(group.Resources),
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
}
|
||||
destinations = append(destinations, minimum)
|
||||
cache[gid] = minimum
|
||||
|
||||
@@ -2,6 +2,7 @@ package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
@@ -10,8 +11,21 @@ import (
|
||||
|
||||
var (
|
||||
reconnectingTimeout = 60 * time.Second
|
||||
mux sync.Mutex
|
||||
)
|
||||
|
||||
func getReconnectingTimeout() time.Duration {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
return reconnectingTimeout
|
||||
}
|
||||
|
||||
func setReconnectingTimeout(timeout time.Duration) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
reconnectingTimeout = timeout
|
||||
}
|
||||
|
||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||
type Guard struct {
|
||||
// OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance.
|
||||
@@ -128,7 +142,7 @@ func exponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 2 * time.Second,
|
||||
Multiplier: 2,
|
||||
MaxInterval: reconnectingTimeout,
|
||||
MaxInterval: getReconnectingTimeout(),
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ type Manager struct {
|
||||
|
||||
relayClient *Client
|
||||
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
|
||||
relayClientMu sync.RWMutex
|
||||
relayClientMu sync.Mutex
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
@@ -124,8 +124,8 @@ func (m *Manager) Serve() error {
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return nil, ErrRelayClientNotConnected
|
||||
@@ -155,8 +155,8 @@ func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (
|
||||
|
||||
// Ready returns true if the home Relay client is connected to the relay server.
|
||||
func (m *Manager) Ready() bool {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return false
|
||||
@@ -174,8 +174,8 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
|
||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||
// closed.
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return ErrRelayClientNotConnected
|
||||
@@ -199,8 +199,8 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
@@ -300,9 +300,7 @@ func (m *Manager) onServerConnected() {
|
||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
m.relayClientMu.Lock()
|
||||
if serverAddress == m.relayClient.connectionURL {
|
||||
go func(client *Client) {
|
||||
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
|
||||
}(m.relayClient)
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
|
||||
}
|
||||
m.relayClientMu.Unlock()
|
||||
|
||||
@@ -352,6 +350,12 @@ func (m *Manager) startCleanupLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getClientLen() int {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
return len(m.relayClients)
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
|
||||
@@ -292,8 +292,8 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
time.Sleep(timeout)
|
||||
if len(mgr.relayClients) != 0 {
|
||||
t.Errorf("expected 0, got %d", len(mgr.relayClients))
|
||||
if mgr.getClientLen() != 0 {
|
||||
t.Errorf("expected 0, got %d", mgr.getClientLen())
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
@@ -301,7 +301,7 @@ func TestForeginAutoClose(t *testing.T) {
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
reconnectingTimeout = 2 * time.Second
|
||||
setReconnectingTimeout(2 * time.Second)
|
||||
|
||||
srvCfg := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
@@ -362,7 +362,7 @@ func TestAutoReconnect(t *testing.T) {
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
time.Sleep(getReconnectingTimeout() + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ctx, ra, "bob")
|
||||
|
||||
@@ -2,14 +2,26 @@ package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
heartbeatTimeout = healthCheckInterval + 10*time.Second
|
||||
)
|
||||
var heartbeatTimeout = getHealthCheckInterval() + 10*time.Second
|
||||
var mux sync.Mutex
|
||||
|
||||
func getHeartBeatTimeout() time.Duration {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
return heartbeatTimeout
|
||||
}
|
||||
|
||||
func setHeartBeatTimeout(interval time.Duration) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
heartbeatTimeout = interval
|
||||
}
|
||||
|
||||
// Receiver is a healthcheck receiver
|
||||
// It will listen for heartbeat and check if the heartbeat is not received in a certain time
|
||||
@@ -56,7 +68,7 @@ func (r *Receiver) Stop() {
|
||||
}
|
||||
|
||||
func (r *Receiver) waitForHealthcheck() {
|
||||
ticker := time.NewTicker(heartbeatTimeout)
|
||||
ticker := time.NewTicker(getHeartBeatTimeout())
|
||||
defer ticker.Stop()
|
||||
defer r.ctxCancel()
|
||||
defer close(r.OnTimeout)
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNewReceiver(t *testing.T) {
|
||||
heartbeatTimeout = 5 * time.Second
|
||||
setHeartBeatTimeout(5 * time.Second)
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
|
||||
select {
|
||||
@@ -23,7 +23,7 @@ func TestNewReceiver(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewReceiverNotReceive(t *testing.T) {
|
||||
heartbeatTimeout = 1 * time.Second
|
||||
setHeartBeatTimeout(1 * time.Second)
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
|
||||
select {
|
||||
@@ -34,7 +34,7 @@ func TestNewReceiverNotReceive(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewReceiverAck(t *testing.T) {
|
||||
heartbeatTimeout = 2 * time.Second
|
||||
setHeartBeatTimeout(2 * time.Second)
|
||||
r := NewReceiver(log.WithContext(context.Background()))
|
||||
|
||||
r.Heartbeat()
|
||||
@@ -59,13 +59,13 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := heartbeatTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
|
||||
originalInterval := getHealthCheckInterval()
|
||||
originalTimeout := getHeartBeatTimeout()
|
||||
setHealthCheckInterval(1 * time.Second)
|
||||
setHeartBeatTimeout(getHealthCheckInterval() + 500*time.Millisecond)
|
||||
defer func() {
|
||||
healthCheckInterval = originalInterval
|
||||
heartbeatTimeout = originalTimeout
|
||||
setHealthCheckInterval(originalInterval)
|
||||
setHeartBeatTimeout(originalTimeout)
|
||||
}()
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
@@ -73,7 +73,7 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
|
||||
|
||||
receiver := NewReceiver(log.WithField("test_name", tc.name))
|
||||
|
||||
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
|
||||
testTimeout := getHeartBeatTimeout()*time.Duration(tc.threshold) + getHealthCheckInterval()
|
||||
|
||||
if tc.resetCounterOnce {
|
||||
receiver.Heartbeat()
|
||||
|
||||
@@ -19,6 +19,30 @@ var (
|
||||
healthCheckTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
func getHealthCheckInterval() time.Duration {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
return healthCheckInterval
|
||||
}
|
||||
|
||||
func setHealthCheckInterval(interval time.Duration) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
healthCheckInterval = interval
|
||||
}
|
||||
|
||||
func getHealthCheckTimeout() time.Duration {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
return healthCheckTimeout
|
||||
}
|
||||
|
||||
func setHealthCheckTimeout(timeout time.Duration) {
|
||||
mux.Lock()
|
||||
defer mux.Unlock()
|
||||
healthCheckTimeout = timeout
|
||||
}
|
||||
|
||||
// Sender is a healthcheck sender
|
||||
// It will send healthcheck signal to the receiver
|
||||
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
|
||||
@@ -57,7 +81,7 @@ func (hc *Sender) OnHCResponse() {
|
||||
}
|
||||
|
||||
func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
ticker := time.NewTicker(healthCheckInterval)
|
||||
ticker := time.NewTicker(getHealthCheckInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTicker := time.NewTicker(hc.getTimeoutTime())
|
||||
@@ -94,7 +118,7 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (hc *Sender) getTimeoutTime() time.Duration {
|
||||
return healthCheckInterval + healthCheckTimeout
|
||||
return getHealthCheckInterval() + getHealthCheckTimeout()
|
||||
}
|
||||
|
||||
func getAttemptThresholdFromEnv() int {
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// override the health check interval to speed up the test
|
||||
healthCheckInterval = 2 * time.Second
|
||||
healthCheckTimeout = 100 * time.Millisecond
|
||||
setHealthCheckInterval(2 * time.Second)
|
||||
setHealthCheckTimeout(100 * time.Millisecond)
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func TestNewHealthPeriod(t *testing.T) {
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
case <-time.After(getHealthCheckInterval() + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func TestNewHealthFailed(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-hc.Timeout:
|
||||
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
|
||||
case <-time.After(getHealthCheckInterval() + getHealthCheckTimeout() + 100*time.Millisecond):
|
||||
t.Fatalf("health check is not timed out")
|
||||
}
|
||||
}
|
||||
@@ -89,7 +89,7 @@ func TestTimeoutReset(t *testing.T) {
|
||||
hc.OnHCResponse()
|
||||
case <-hc.Timeout:
|
||||
t.Fatalf("health check is timed out")
|
||||
case <-time.After(healthCheckInterval + 100*time.Millisecond):
|
||||
case <-time.After(getHealthCheckInterval() + 100*time.Millisecond):
|
||||
t.Fatalf("health check not received")
|
||||
}
|
||||
}
|
||||
@@ -118,13 +118,13 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
|
||||
for _, tc := range testsCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
originalInterval := healthCheckInterval
|
||||
originalTimeout := healthCheckTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
healthCheckTimeout = 500 * time.Millisecond
|
||||
originalInterval := getHealthCheckInterval()
|
||||
originalTimeout := getHealthCheckTimeout()
|
||||
setHealthCheckInterval(1 * time.Second)
|
||||
setHealthCheckTimeout(500 * time.Millisecond)
|
||||
defer func() {
|
||||
healthCheckInterval = originalInterval
|
||||
healthCheckTimeout = originalTimeout
|
||||
setHealthCheckInterval(originalInterval)
|
||||
setHealthCheckTimeout(originalTimeout)
|
||||
}()
|
||||
|
||||
//nolint:tenv
|
||||
@@ -155,7 +155,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval
|
||||
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + getHealthCheckInterval()
|
||||
|
||||
select {
|
||||
case <-sender.Timeout:
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -18,16 +17,11 @@ type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte
|
||||
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
||||
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
||||
|
||||
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
|
||||
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||
|
||||
var (
|
||||
listenerWriteHooksMutex sync.RWMutex
|
||||
listenerWriteHooks []ListenerWriteHookFunc
|
||||
listenerCloseHooksMutex sync.RWMutex
|
||||
listenerCloseHooks []ListenerCloseHookFunc
|
||||
listenerAddressRemoveHooksMutex sync.RWMutex
|
||||
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
|
||||
listenerWriteHooksMutex sync.RWMutex
|
||||
listenerWriteHooks []ListenerWriteHookFunc
|
||||
listenerCloseHooksMutex sync.RWMutex
|
||||
listenerCloseHooks []ListenerCloseHookFunc
|
||||
)
|
||||
|
||||
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
||||
@@ -44,14 +38,7 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) {
|
||||
listenerCloseHooks = append(listenerCloseHooks, hook)
|
||||
}
|
||||
|
||||
// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
|
||||
listenerAddressRemoveHooksMutex.Lock()
|
||||
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
|
||||
}
|
||||
|
||||
// RemoveListenerHooks removes all listener hooks.
|
||||
// RemoveListenerHooks removes all dialer hooks.
|
||||
func RemoveListenerHooks() {
|
||||
listenerWriteHooksMutex.Lock()
|
||||
defer listenerWriteHooksMutex.Unlock()
|
||||
@@ -60,10 +47,6 @@ func RemoveListenerHooks() {
|
||||
listenerCloseHooksMutex.Lock()
|
||||
defer listenerCloseHooksMutex.Unlock()
|
||||
listenerCloseHooks = nil
|
||||
|
||||
listenerAddressRemoveHooksMutex.Lock()
|
||||
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||
listenerAddressRemoveHooks = nil
|
||||
}
|
||||
|
||||
// ListenPacket listens on the network address and returns a PacketConn
|
||||
@@ -78,7 +61,6 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri
|
||||
return nil, fmt.Errorf("listen packet: %w", err)
|
||||
}
|
||||
connID := GenerateConnID()
|
||||
|
||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||
}
|
||||
|
||||
@@ -120,45 +102,6 @@ func (c *UDPConn) Close() error {
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
}
|
||||
|
||||
// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality
|
||||
func WrapUDPConn(conn *net.UDPConn) *UDPConn {
|
||||
return &UDPConn{
|
||||
UDPConn: conn,
|
||||
ID: GenerateConnID(),
|
||||
seenAddrs: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||
func (c *UDPConn) RemoveAddress(addr string) {
|
||||
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||
return
|
||||
}
|
||||
|
||||
ipStr, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Errorf("Error splitting IP address and port: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ipAddr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||
return
|
||||
}
|
||||
|
||||
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
|
||||
|
||||
listenerAddressRemoveHooksMutex.RLock()
|
||||
defer listenerAddressRemoveHooksMutex.RUnlock()
|
||||
|
||||
for _, hook := range listenerAddressRemoveHooks {
|
||||
if err := hook(c.ID, prefix); err != nil {
|
||||
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
||||
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking
|
||||
func WrapUDPConn(conn *net.UDPConn) *net.UDPConn {
|
||||
return conn
|
||||
}
|
||||
Reference in New Issue
Block a user