mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[client] Use stdnet with a context to avoid DNS deadlocks (#4781)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package iface
|
package iface
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -9,13 +10,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
// keep darwin compatibility
|
// keep darwin compatibility
|
||||||
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
addr := "100.64.0.1/8"
|
addr := "100.64.0.1/8"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
|
|||||||
func Test_CreateInterface(t *testing.T) {
|
func Test_CreateInterface(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
|
||||||
wgIP := "10.99.99.1/32"
|
wgIP := "10.99.99.1/32"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
|
||||||
wgIP := "10.99.99.2/32"
|
wgIP := "10.99.99.2/32"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
|
||||||
wgIP := "10.99.99.5/30"
|
wgIP := "10.99.99.5/30"
|
||||||
wgPort := 33100
|
wgPort := 33100
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
|
|||||||
func Test_UpdatePeer(t *testing.T) {
|
func Test_UpdatePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.9/30"
|
wgIP := "10.99.99.9/30"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
|
|||||||
func Test_RemovePeer(t *testing.T) {
|
func Test_RemovePeer(t *testing.T) {
|
||||||
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
|
||||||
wgIP := "10.99.99.13/30"
|
wgIP := "10.99.99.13/30"
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
peer2wgPort := 33200
|
peer2wgPort := 33200
|
||||||
|
|
||||||
keepAlive := 1 * time.Second
|
keepAlive := 1 * time.Second
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
|
|||||||
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
guid = fmt.Sprintf("{%s}", uuid.New().String())
|
||||||
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
device.CustomWindowsGUIDString = strings.ToLower(guid)
|
||||||
|
|
||||||
newNet, err = stdnet.NewNet()
|
newNet, err = stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package udpmux
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -12,8 +13,9 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
|||||||
if len(networks) > 0 {
|
if len(networks) > 0 {
|
||||||
if m.params.Net == nil {
|
if m.params.Net == nil {
|
||||||
var err error
|
var err error
|
||||||
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
|
||||||
m.params.Logger.Errorf("failed to get create network: %v", err)
|
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
privKey, _ := wgtypes.GenerateKey()
|
privKey, _ := wgtypes.GenerateKey()
|
||||||
newNet, err := stdnet.NewNet(nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create stdnet: %v", err)
|
t.Errorf("create stdnet: %v", err)
|
||||||
return
|
return
|
||||||
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||||
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
newNet, err := stdnet.NewNet([]string{"utun2301"})
|
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create stdnet: %v", err)
|
t.Fatalf("create stdnet: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -7,5 +7,5 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(e.config.IFaceBlackList)
|
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,5 +3,5 @@ package internal
|
|||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
|
||||||
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
func (e *Engine) newStdNet() (*stdnet.Net, error) {
|
||||||
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -977,7 +977,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
|||||||
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
|
||||||
log.Debugf("Gathering ICE candidates")
|
log.Debugf("Gathering ICE candidates")
|
||||||
|
|
||||||
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("create ICE agent: %w", err)
|
return false, fmt.Errorf("create ICE agent: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -49,13 +50,13 @@ func (a *ThreadSafeAgent) Close() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
|
||||||
iceKeepAlive := iceKeepAlive()
|
iceKeepAlive := iceKeepAlive()
|
||||||
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
iceDisconnectedTimeout := iceDisconnectedTimeout()
|
||||||
iceFailedTimeout := iceFailedTimeout()
|
iceFailedTimeout := iceFailedTimeout()
|
||||||
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
|
||||||
|
|
||||||
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
|
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create pion's stdnet: %s", err)
|
log.Errorf("failed to create pion's stdnet: %s", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,11 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
return stdnet.NewNet(ifaceBlacklist)
|
return stdnet.NewNet(ctx, ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
package ice
|
package ice
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/stdnet"
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
|
)
|
||||||
|
|
||||||
|
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
|
||||||
|
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
|
||||||
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create agent: %w", err)
|
return nil, fmt.Errorf("create agent: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(nil)
|
net, err := stdnet.NewNet(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
net, err := stdnet.NewNet(nil)
|
net, err := stdnet.NewNet(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("new net: %w", err)
|
probeErr = fmt.Errorf("new net: %w", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
for n, testCase := range testCases {
|
for n, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
|
|||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
opts := iface.WGIFaceOpts{
|
opts := iface.WGIFaceOpts{
|
||||||
|
|||||||
@@ -4,17 +4,28 @@
|
|||||||
package stdnet
|
package stdnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
const updateInterval = 30 * time.Second
|
const (
|
||||||
|
updateInterval = 30 * time.Second
|
||||||
|
dnsResolveTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoSuitableAddress = errors.New("no suitable address found")
|
||||||
|
|
||||||
// Net is an implementation of the net.Net interface
|
// Net is an implementation of the net.Net interface
|
||||||
// based on functions of the standard net package.
|
// based on functions of the standard net package.
|
||||||
@@ -28,12 +39,19 @@ type Net struct {
|
|||||||
|
|
||||||
// mu is shared between interfaces and lastUpdate
|
// mu is shared between interfaces and lastUpdate
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// ctx is the context for network operations that supports cancellation
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetWithDiscover creates a new StdNet instance.
|
// NewNetWithDiscover creates a new StdNet instance.
|
||||||
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
n := &Net{
|
n := &Net{
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
|
||||||
// so in android cli use pionDiscover
|
// so in android cli use pionDiscover
|
||||||
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewNet creates a new StdNet instance.
|
// NewNet creates a new StdNet instance.
|
||||||
func NewNet(disallowList []string) (*Net, error) {
|
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
n := &Net{
|
n := &Net{
|
||||||
iFaceDiscover: pionDiscover{},
|
iFaceDiscover: pionDiscover{},
|
||||||
interfaceFilter: InterfaceFilter(disallowList),
|
interfaceFilter: InterfaceFilter(disallowList),
|
||||||
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
return n, n.UpdateInterfaces()
|
return n, n.UpdateInterfaces()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveAddr performs DNS resolution with context support and timeout.
|
||||||
|
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
|
||||||
|
host, portStr, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
|
||||||
|
}
|
||||||
|
if port < 0 || port > 65535 {
|
||||||
|
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipNet := "ip"
|
||||||
|
switch network {
|
||||||
|
case "tcp4", "udp4":
|
||||||
|
ipNet = "ip4"
|
||||||
|
case "tcp6", "udp6":
|
||||||
|
ipNet = "ip6"
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
addr := netip.IPv4Unspecified()
|
||||||
|
if ipNet == "ip6" {
|
||||||
|
addr = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
return netip.AddrPortFrom(addr, uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
|
||||||
|
if err != nil {
|
||||||
|
return netip.AddrPort{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return netip.AddrPort{}, errNoSuitableAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateInterfaces updates the internal list of network interfaces
|
// UpdateInterfaces updates the internal list of network interfaces
|
||||||
// and associated addresses filtering them by name.
|
// and associated addresses filtering them by name.
|
||||||
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
|
||||||
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
|
||||||
|
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
|
||||||
|
switch network {
|
||||||
|
case "udp", "udp4", "udp6":
|
||||||
|
case "":
|
||||||
|
network = "udp"
|
||||||
|
default:
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := n.resolveAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.UDPAddrFromAddrPort(addrPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
|
||||||
|
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "tcp6":
|
||||||
|
case "":
|
||||||
|
network = "tcp"
|
||||||
|
default:
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
|
||||||
|
}
|
||||||
|
|
||||||
|
addrPort, err := n.resolveAddr(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
return net.TCPAddrFromAddrPort(addrPort), nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user