[client] Use stdnet with a context to avoid DNS deadlocks (#4781)

This commit is contained in:
Viktor Liu
2025-11-13 20:16:45 +01:00
committed by GitHub
parent 3176b53968
commit 9cc9462cd5
15 changed files with 153 additions and 39 deletions

View File

@@ -4,17 +4,28 @@
package stdnet
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"slices"
"strconv"
"sync"
"time"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/pion/transport/v3"
"github.com/pion/transport/v3/stdnet"
"github.com/netbirdio/netbird/client/iface/netstack"
)
const updateInterval = 30 * time.Second
const (
updateInterval = 30 * time.Second
dnsResolveTimeout = 30 * time.Second
)
var errNoSuitableAddress = errors.New("no suitable address found")
// Net is an implementation of the net.Net interface
// based on functions of the standard net package.
@@ -28,12 +39,19 @@ type Net struct {
// mu is shared between interfaces and lastUpdate
mu sync.Mutex
// ctx is the context for network operations that supports cancellation
ctx context.Context
}
// NewNetWithDiscover creates a new StdNet instance.
func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
// current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client
// so in android cli use pionDiscover
@@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
}
// NewNet creates a new StdNet instance.
func NewNet(disallowList []string) (*Net, error) {
func NewNet(ctx context.Context, disallowList []string) (*Net, error) {
if ctx == nil {
ctx = context.Background()
}
n := &Net{
iFaceDiscover: pionDiscover{},
interfaceFilter: InterfaceFilter(disallowList),
ctx: ctx,
}
return n, n.UpdateInterfaces()
}
// resolveAddr performs DNS resolution with context support and timeout.
func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) {
host, portStr, err := net.SplitHostPort(address)
if err != nil {
return netip.AddrPort{}, err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err)
}
if port < 0 || port > 65535 {
return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port)
}
ipNet := "ip"
switch network {
case "tcp4", "udp4":
ipNet = "ip4"
case "tcp6", "udp6":
ipNet = "ip6"
}
if host == "" {
addr := netip.IPv4Unspecified()
if ipNet == "ip6" {
addr = netip.IPv6Unspecified()
}
return netip.AddrPortFrom(addr, uint16(port)), nil
}
ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout)
defer cancel()
addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host)
if err != nil {
return netip.AddrPort{}, err
}
if len(addrs) == 0 {
return netip.AddrPort{}, errNoSuitableAddress
}
return netip.AddrPortFrom(addrs[0], uint16(port)), nil
}
// UpdateInterfaces updates the internal list of network interfaces
// and associated addresses filtering them by name.
// The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one
@@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I
}
return result
}
// ResolveUDPAddr resolves UDP addresses with context support and timeout.
func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) {
switch network {
case "udp", "udp4", "udp6":
case "":
network = "udp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err}
}
return net.UDPAddrFromAddrPort(addrPort), nil
}
// ResolveTCPAddr resolves TCP addresses with context support and timeout.
func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) {
switch network {
case "tcp", "tcp4", "tcp6":
case "":
network = "tcp"
default:
return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)}
}
addrPort, err := n.resolveAddr(network, address)
if err != nil {
return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err}
}
return net.TCPAddrFromAddrPort(addrPort), nil
}