diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index e890b30f3..6bbfeaa63 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -1,6 +1,7 @@ package iface import ( + "context" "fmt" "net" "net/netip" @@ -9,13 +10,13 @@ import ( "time" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/stdnet" ) // keep darwin compatibility @@ -40,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) addr := "100.64.0.1/8" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -123,7 +124,7 @@ func getIfaceAddrs(ifaceName string) ([]net.Addr, error) { func Test_CreateInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+1) wgIP := "10.99.99.1/32" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -166,7 +167,7 @@ func Test_Close(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -211,7 +212,7 @@ func TestRecreation(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) wgIP := "10.99.99.2/32" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -284,7 +285,7 @@ func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" wgPort := 33100 - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -339,7 +340,7 @@ func Test_ConfigureInterface(t *testing.T) { func Test_UpdatePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.9/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -409,7 +410,7 @@ func Test_UpdatePeer(t *testing.T) { func Test_RemovePeer(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+4) wgIP := "10.99.99.13/30" - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -471,7 +472,7 @@ func Test_ConnectPeers(t *testing.T) { peer2wgPort := 33200 keepAlive := 1 * time.Second - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -514,7 +515,7 @@ func Test_ConnectPeers(t *testing.T) { guid = fmt.Sprintf("{%s}", uuid.New().String()) device.CustomWindowsGUIDString = strings.ToLower(guid) - newNet, err = stdnet.NewNet() + newNet, err = stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/iface/udpmux/mux.go b/client/iface/udpmux/mux.go index 319724926..c5d2de4a5 100644 --- a/client/iface/udpmux/mux.go +++ b/client/iface/udpmux/mux.go @@ -1,6 +1,7 @@ package udpmux import ( + "context" "fmt" "io" "net" @@ -12,8 +13,9 @@ import ( "github.com/pion/logging" "github.com/pion/stun/v3" "github.com/pion/transport/v3" - "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/stdnet" ) /* @@ -199,7 +201,7 @@ func (m *SingleSocketUDPMux) updateLocalAddresses() { if len(networks) > 0 { if m.params.Net == nil { var err error - if m.params.Net, err = stdnet.NewNet(); err != nil { + if m.params.Net, err = stdnet.NewNet(context.Background(), nil); err != nil { m.params.Logger.Errorf("failed to get create network: %v", err) } } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 451b83f92..d12070128 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -335,7 +335,7 @@ func TestUpdateDNSServer(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { privKey, _ := wgtypes.GenerateKey() - newNet, err := stdnet.NewNet(nil) + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -434,7 +434,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Errorf("create stdnet: %v", err) return @@ -915,7 +915,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { defer t.Setenv("NB_WG_KERNEL_DISABLED", ov) t.Setenv("NB_WG_KERNEL_DISABLED", "true") - newNet, err := stdnet.NewNet([]string{"utun2301"}) + newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"}) if err != nil { t.Fatalf("create stdnet: %v", err) return nil, err diff --git a/client/internal/engine_stdnet.go b/client/internal/engine_stdnet.go index 9e171b0b2..1ebb5779c 100644 --- a/client/internal/engine_stdnet.go +++ b/client/internal/engine_stdnet.go @@ -7,5 +7,5 @@ import ( ) func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(e.config.IFaceBlackList) + return stdnet.NewNet(e.clientCtx, e.config.IFaceBlackList) } diff --git a/client/internal/engine_stdnet_android.go b/client/internal/engine_stdnet_android.go index 68a0ae719..de3c80bcf 100644 --- a/client/internal/engine_stdnet_android.go +++ b/client/internal/engine_stdnet_android.go @@ -3,5 +3,5 @@ package internal import "github.com/netbirdio/netbird/client/internal/stdnet" func (e *Engine) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) + return stdnet.NewNetWithDiscover(e.clientCtx, e.mobileDep.IFaceDiscover, e.config.IFaceBlackList) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 15ac0a947..d15a07f9d 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -14,7 +14,7 @@ import ( "github.com/golang/mock/gomock" "github.com/google/uuid" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -774,7 +774,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { MTU: iface.DefaultMTU, }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -977,7 +977,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) engine.ctx = ctx - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 0f22ee7b0..a201dd095 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -78,7 +78,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { func (cm *ICEMonitor) handleCandidateTick(ctx context.Context, ufrag string, pwd string) (bool, error) { log.Debugf("Gathering ICE candidates") - agent, err := icemaker.NewAgent(cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) + agent, err := icemaker.NewAgent(ctx, cm.iFaceDiscover, cm.iceConfig, candidateTypesP2P(), ufrag, pwd) if err != nil { return false, fmt.Errorf("create ICE agent: %w", err) } diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 7b929c29d..79f68d279 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -1,6 +1,7 @@ package ice import ( + "context" "sync" "time" @@ -49,13 +50,13 @@ func (a *ThreadSafeAgent) Close() error { return err } -func NewAgent(iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { +func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() iceFailedTimeout := iceFailedTimeout() iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - transportNet, err := newStdNet(iFaceDiscover, config.InterfaceBlackList) + transportNet, err := newStdNet(ctx, iFaceDiscover, config.InterfaceBlackList) if err != nil { log.Errorf("failed to create pion's stdnet: %s", err) } diff --git a/client/internal/peer/ice/stdnet.go b/client/internal/peer/ice/stdnet.go index 3ce83727e..685ed0363 100644 --- a/client/internal/peer/ice/stdnet.go +++ b/client/internal/peer/ice/stdnet.go @@ -3,9 +3,11 @@ package ice import ( + "context" + "github.com/netbirdio/netbird/client/internal/stdnet" ) -func newStdNet(_ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNet(ifaceBlacklist) +func newStdNet(ctx context.Context, _ stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNet(ctx, ifaceBlacklist) } diff --git a/client/internal/peer/ice/stdnet_android.go b/client/internal/peer/ice/stdnet_android.go index 84c665e6f..5033ec1b9 100644 --- a/client/internal/peer/ice/stdnet_android.go +++ b/client/internal/peer/ice/stdnet_android.go @@ -1,7 +1,11 @@ package ice -import "github.com/netbirdio/netbird/client/internal/stdnet" +import ( + "context" -func newStdNet(iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(iFaceDiscover, ifaceBlacklist) + "github.com/netbirdio/netbird/client/internal/stdnet" +) + +func newStdNet(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, ifaceBlacklist []string) (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(ctx, iFaceDiscover, ifaceBlacklist) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 5d8ebfe45..840fc9241 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -209,7 +209,7 @@ func (w *WorkerICE) Close() { } func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) { - agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) + agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd) if err != nil { return nil, fmt.Errorf("create agent: %w", err) } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 693ea1f31..59be5b0a7 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -197,7 +197,7 @@ func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr stri } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return @@ -286,7 +286,7 @@ func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr stri } }() - net, err := stdnet.NewNet(nil) + net, err := stdnet.NewNet(ctx, nil) if err != nil { probeErr = fmt.Errorf("new net: %w", err) return diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index d2f02526c..3697545ae 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -6,7 +6,7 @@ import ( "net/netip" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/stretchr/testify/require" @@ -403,7 +403,7 @@ func TestManagerUpdateRoutes(t *testing.T) { for n, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { peerPrivateKey, _ := wgtypes.GeneratePrivateKey() - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index d9b109beb..01916fbe3 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -15,7 +15,7 @@ import ( "syscall" "testing" - "github.com/pion/transport/v3/stdnet" + "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -436,7 +436,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen peerPrivateKey, err := wgtypes.GeneratePrivateKey() require.NoError(t, err) - newNet, err := stdnet.NewNet() + newNet, err := stdnet.NewNet(context.Background(), nil) require.NoError(t, err) opts := iface.WGIFaceOpts{ diff --git a/client/internal/stdnet/stdnet.go b/client/internal/stdnet/stdnet.go index 4b031c05c..381886ac6 100644 --- a/client/internal/stdnet/stdnet.go +++ b/client/internal/stdnet/stdnet.go @@ -4,17 +4,28 @@ package stdnet import ( + "context" + "errors" "fmt" + "net" + "net/netip" "slices" + "strconv" "sync" "time" - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + + "github.com/netbirdio/netbird/client/iface/netstack" ) -const updateInterval = 30 * time.Second +const ( + updateInterval = 30 * time.Second + dnsResolveTimeout = 30 * time.Second +) + +var errNoSuitableAddress = errors.New("no suitable address found") // Net is an implementation of the net.Net interface // based on functions of the standard net package. @@ -28,12 +39,19 @@ type Net struct { // mu is shared between interfaces and lastUpdate mu sync.Mutex + + // ctx is the context for network operations that supports cancellation + ctx context.Context } // NewNetWithDiscover creates a new StdNet instance. -func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { +func NewNetWithDiscover(ctx context.Context, iFaceDiscover ExternalIFaceDiscover, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } // current ExternalIFaceDiscover implement in android-client https://github.dev/netbirdio/android-client // so in android cli use pionDiscover @@ -46,14 +64,64 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri } // NewNet creates a new StdNet instance. -func NewNet(disallowList []string) (*Net, error) { +func NewNet(ctx context.Context, disallowList []string) (*Net, error) { + if ctx == nil { + ctx = context.Background() + } n := &Net{ iFaceDiscover: pionDiscover{}, interfaceFilter: InterfaceFilter(disallowList), + ctx: ctx, } return n, n.UpdateInterfaces() } +// resolveAddr performs DNS resolution with context support and timeout. +func (n *Net) resolveAddr(network, address string) (netip.AddrPort, error) { + host, portStr, err := net.SplitHostPort(address) + if err != nil { + return netip.AddrPort{}, err + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("invalid port: %w", err) + } + if port < 0 || port > 65535 { + return netip.AddrPort{}, fmt.Errorf("invalid port: %d", port) + } + + ipNet := "ip" + switch network { + case "tcp4", "udp4": + ipNet = "ip4" + case "tcp6", "udp6": + ipNet = "ip6" + } + + if host == "" { + addr := netip.IPv4Unspecified() + if ipNet == "ip6" { + addr = netip.IPv6Unspecified() + } + return netip.AddrPortFrom(addr, uint16(port)), nil + } + + ctx, cancel := context.WithTimeout(n.ctx, dnsResolveTimeout) + defer cancel() + + addrs, err := net.DefaultResolver.LookupNetIP(ctx, ipNet, host) + if err != nil { + return netip.AddrPort{}, err + } + + if len(addrs) == 0 { + return netip.AddrPort{}, errNoSuitableAddress + } + + return netip.AddrPortFrom(addrs[0], uint16(port)), nil +} + // UpdateInterfaces updates the internal list of network interfaces // and associated addresses filtering them by name. // The interfaces are discovered by an external iFaceDiscover function or by a default discoverer if the external one @@ -137,3 +205,39 @@ func (n *Net) filterInterfaces(interfaces []*transport.Interface) []*transport.I } return result } + +// ResolveUDPAddr resolves UDP addresses with context support and timeout. +func (n *Net) ResolveUDPAddr(network, address string) (*net.UDPAddr, error) { + switch network { + case "udp", "udp4", "udp6": + case "": + network = "udp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.UDPAddr{IP: nil}, Err: err} + } + + return net.UDPAddrFromAddrPort(addrPort), nil +} + +// ResolveTCPAddr resolves TCP addresses with context support and timeout. +func (n *Net) ResolveTCPAddr(network, address string) (*net.TCPAddr, error) { + switch network { + case "tcp", "tcp4", "tcp6": + case "": + network = "tcp" + default: + return nil, &net.OpError{Op: "resolve", Net: network, Err: net.UnknownNetworkError(network)} + } + + addrPort, err := n.resolveAddr(network, address) + if err != nil { + return nil, &net.OpError{Op: "resolve", Net: network, Addr: &net.TCPAddr{IP: nil}, Err: err} + } + + return net.TCPAddrFromAddrPort(addrPort), nil +}