[client] Use stdnet with a context to avoid DNS deadlocks (#4781)

This commit is contained in:
Viktor Liu
2025-11-13 20:16:45 +01:00
committed by GitHub
parent 3176b53968
commit 9cc9462cd5
15 changed files with 153 additions and 39 deletions

View File

@@ -1,6 +1,7 @@
package iface
import (
"context"
"fmt"
"net"
"net/netip"
@@ -9,13 +10,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// keep darwin compatibility
@@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
addr := "100.64.0.1/8"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) {
func Test_CreateInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1)
wgIP := "10.99.99.1/32"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -166,7 +167,7 @@ func Test_Close(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
wgIP := "10.99.99.2/32"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
wgPort := 33100
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) {
func Test_UpdatePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.9/30"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) {
func Test_RemovePeer(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4)
wgIP := "10.99.99.13/30"
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) {
peer2wgPort := 33200
keepAlive := 1 * time.Second
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) {
guid = fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
newNet, err = stdnet.NewNet()
newNet, err = stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,6 +1,7 @@
package udpmux
import (
"context"
"fmt"
"io"
"net"
@@ -12,8 +13,9 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v3"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
/*
@@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() {
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}

View File

@@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(nil)
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet([]string{"utun2301"})
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
@@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet([]string{"utun2301"})
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Fatalf("create stdnet: %v", err)
return nil, err

View File

@@ -7,5 +7,5 @@ import (
)
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNet(e.config.IFaceBlackList)
return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList)
}

View File

@@ -3,5 +3,5 @@ package internal
import "github.com/netbirdio/netbird/client/internal/stdnet"
func (e *Engine) newStdNet() (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList)
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -977,7 +977,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) {
log.Debugf("Gathering ICE candidates")
agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd)
if err != nil {
return false, fmt.Errorf("create ICE agent: %w", err)
}

View File

@@ -1,6 +1,7 @@
package ice
import (
"context"
"sync"
"time"
@@ -49,13 +50,13 @@ func (a *ThreadSafeAgent) Close() error {
return err
}
func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
iceFailedTimeout := iceFailedTimeout()
iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList)
transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList)
if err != nil {
log.Errorf("failed to create pion's stdnet: %s", err)
}

View File

@@ -3,9 +3,11 @@
package ice
import (
"context"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ifaceBlacklist)
func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNet(ctx, ifaceBlacklist)
}

View File

@@ -1,7 +1,11 @@
package ice
import "github.com/netbirdio/netbird/client/internal/stdnet"
import (
"context"
func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist)
"github.com/netbirdio/netbird/client/internal/stdnet"
)
func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) {
return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist)
}

View File

@@ -209,7 +209,7 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}

View File

@@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
net, err := stdnet.NewNet(nil)
net, err := stdnet.NewNet(ctx, nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return
@@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri
}
}()
net, err := stdnet.NewNet(nil)
net, err := stdnet.NewNet(ctx, nil)
if err != nil {
probeErr = fmt.Errorf("new net: %w", err)
return

View File

@@ -6,7 +6,7 @@ import (
"net/netip"
"testing"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/stretchr/testify/require"
@@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
peerPrivateKey, _ := wgtypes.GeneratePrivateKey()
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -15,7 +15,7 @@ import (
"syscall"
"testing"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
newNet, err := stdnet.NewNet()
newNet, err := stdnet.NewNet(context.Background(), nil)
require.NoError(t, err)
opts := iface.WGIFaceOpts{

View File

@@ -4,17 +4,28 @@
package stdnet
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strconv"
"sync"
"time"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const updateInterval = 30 * time.Second
const (
updateInterval = 30 * time.Second
dnsResolveTimeout = 30 * time.Second
)
var errNoSuitableAddress = errors.New("no suitable address found")
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
@@ -28,12 +39,19 @@ type Net struct {
// mu is shared between interfaces and lastUpdate
mu sync.Mutex
// ctx is the context for network operations that supports cancellation
ctx context.Context
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
// so in android cli use pionDiscover
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
}
// NewNet creates a new StdNet instance.
func NewNet(disallowList []string) (*Net, error) {
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
return n, n.UpdateInterfaces()
}
// resolveAddr performs DNS resolution with context support and timeout.
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return netip.AddrPort{}, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
}
if port < 0 || port > 65535 {
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
}
ipNet := "ip"
switch network {
case "tcp4", "udp4":
ipNet = "ip4"
case "tcp6", "udp6":
ipNet = "ip6"
}
if host == "" {
addr := netip.IPv4Unspecified()
if ipNet == "ip6" {
addr = netip.IPv6Unspecified()
}
return netip.AddrPortFrom(addr, uint16(port)), nil
}
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
defer cancel()
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
if err != nil {
return netip.AddrPort{}, err
}
if len(addrs) == 0 {
return netip.AddrPort{}, errNoSuitableAddress
}
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
}
// UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
}
return result
}
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
switch network {
case "udp", "udp4", "udp6":
case "":
network = "udp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
}
return net.UDPAddrFromAddrPort(addrPort), nil
}
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
switch network {
case "tcp", "tcp4", "tcp6":
case "":
network = "tcp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
}
return net.TCPAddrFromAddrPort(addrPort), nil
}