diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ee45a6ba0..7d1065c28 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -173,17 +173,21 @@ func (conn *Conn) reCreateAgent() error { if err != nil { log.Warnf("failed to create pion's stdnet: %s", err) } + hostWait := 500 * time.Millisecond + srflxWait := 1000 * time.Millisecond agentConfig := &ice.AgentConfig{ MulticastDNSMode: ice.MulticastDNSModeDisabled, NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: conn.config.StunTurn, - CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}, - FailedTimeout: &failedTimeout, - InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), - UDPMux: conn.config.UDPMux, - UDPMuxSrflx: conn.config.UDPMuxSrflx, - NAT1To1IPs: conn.config.NATExternalIPs, - Net: transportNet, + // Urls: conn.config.StunTurn, + CandidateTypes: []ice.CandidateType{ice.CandidateTypeHost}, + FailedTimeout: &failedTimeout, + InterfaceFilter: interfaceFilter(conn.config.InterfaceBlackList), + UDPMux: conn.config.UDPMux, + UDPMuxSrflx: conn.config.UDPMuxSrflx, + NAT1To1IPs: conn.config.NATExternalIPs, + Net: transportNet, + HostAcceptanceMinWait: &hostWait, + SrflxAcceptanceMinWait: &srflxWait, } if conn.config.DisableIPv6Discovery { @@ -423,7 +427,7 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error { } func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy { - + return proxy.NewWireguardProxy(conn.config.ProxyConfig) useProxy := shouldUseProxy(pair) localDirectMode := !useProxy remoteDirectMode := localDirectMode diff --git a/go.mod b/go.mod index 8b7c266be..043b7a6d6 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/vishvananda/netlink v1.1.0 golang.org/x/crypto v0.7.0 golang.org/x/sys v0.6.0 - golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 + golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/windows v0.5.1 google.golang.org/grpc v1.52.3 diff --git a/go.sum b/go.sum index e50c4a033..79b5cbd44 100644 --- a/go.sum +++ b/go.sum @@ -888,6 +888,8 @@ golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+D golang.zx2c4.com/wireguard v0.0.0-20211129173154-2dd424e2d808/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 h1:3zl8RkJNQ8wfPRomwv/6DBbH2Ut6dgMaWTxM0ZunWnE= golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434/go.mod h1:TjUWrnD5ATh7bFvmm/ALEJZQ4ivKbETb6pmyj1vUoNI= +golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675 h1:/J/RVnr7ng4fWPRH3xa4WtBJ1Jp+Auu4YNLmGiPv5QU= +golang.zx2c4.com/wireguard v0.0.0-20230223181233-21636207a675/go.mod h1:whfbyDBt09xhCYQWtO2+3UVjlaq6/9hDZrjg2ZE6SyA= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de h1:qDZ+lyO5jC9RNJ7ANJA0GWXk3pSn0Fu5SlcAIlgw+6w= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de/go.mod h1:Q2XNgour4QSkFj0BWCkVlW0HWJwQgNMsMahpSlI0Eno= golang.zx2c4.com/wireguard/windows v0.5.1 h1:OnYw96PF+CsIMrqWo5QP3Q59q5hY1rFErk/yN3cS+JQ= diff --git a/iface/bind/bind.go b/iface/bind/bind.go index 458f70f22..442e6ad12 100644 --- a/iface/bind/bind.go +++ b/iface/bind/bind.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/pion/stun" + "github.com/pion/transport/v2/stdnet" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/conn" "net" @@ -43,7 +44,11 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { return nil, 0, err } b.sharedConn = ipv4Conn - b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn}) + newNet, err := stdnet.NewNet() + if err != nil { + return nil, 0, err + } + b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn, Net: newNet}) portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String()) if err != nil { @@ -99,11 +104,7 @@ func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { } if !stun.IsMessage(buff[:20]) { // WireGuard traffic - return n, (*conn.StdNetEndpoint)(&net.UDPAddr{ - IP: e.Addr().AsSlice(), - Port: int(e.Port()), - Zone: e.Addr().Zone(), - }), nil + return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil } msg, err := parseSTUNMessage(buff[:n]) @@ -155,20 +156,43 @@ func (b *ICEBind) SetMark(mark uint32) error { } func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error { - nend, ok := endpoint.(*conn.StdNetEndpoint) + + nend, ok := endpoint.(conn.StdNetEndpoint) if !ok { return conn.ErrWrongEndpointType } - _, err := b.sharedConn.WriteTo(buff, (*net.UDPAddr)(nend)) + addrPort := netip.AddrPort(nend) + _, err := b.sharedConn.WriteTo(buff, &net.UDPAddr{ + IP: addrPort.Addr().AsSlice(), + Port: int(addrPort.Port()), + Zone: addrPort.Addr().Zone(), + }) return err } // ParseEndpoint creates a new endpoint from a string. func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { e, err := netip.ParseAddrPort(s) - return (*conn.StdNetEndpoint)(&net.UDPAddr{ - IP: e.Addr().AsSlice(), - Port: int(e.Port()), - Zone: e.Addr().Zone(), - }), err + return asEndpoint(e), err +} + +// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. +// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates, +// but Endpoints are immutable, so we can re-use them. +var endpointPool = sync.Pool{ + New: func() any { + return make(map[netip.AddrPort]conn.Endpoint) + }, +} + +// asEndpoint returns an Endpoint containing ap. +func asEndpoint(ap netip.AddrPort) conn.Endpoint { + m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint) + defer endpointPool.Put(m) + e, ok := m[ap] + if !ok { + e = conn.Endpoint(conn.StdNetEndpoint(ap)) + m[ap] = e + } + return e } diff --git a/iface/bind/udp_mux.go b/iface/bind/udp_mux.go index 5b46c637f..980f7eba5 100644 --- a/iface/bind/udp_mux.go +++ b/iface/bind/udp_mux.go @@ -2,7 +2,9 @@ package bind import ( "fmt" + "github.com/pion/ice/v2" "github.com/pion/stun" + "github.com/pion/transport/v2/stdnet" log "github.com/sirupsen/logrus" "io" "net" @@ -35,6 +37,8 @@ type UDPMuxDefault struct { // for UDP connection listen at unspecified address localAddrsForUnspecified []net.Addr + + used bool } const maxAddrSize = 512 @@ -50,12 +54,139 @@ type UDPMuxParams struct { Net transport.Net } +func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit + ips := []net.IP{} + ifaces, err := n.Interfaces() + if err != nil { + return ips, err + } + + var IPv4Requested, IPv6Requested bool + for _, typ := range networkTypes { + if typ.IsIPv4() { + IPv4Requested = true + } + + if typ.IsIPv6() { + IPv6Requested = true + } + } + + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // interface down + } + if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback { + continue // loopback interface + } + + if interfaceFilter != nil && !interfaceFilter(iface.Name) { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + var ip net.IP + switch addr := addr.(type) { + case *net.IPNet: + ip = addr.IP + case *net.IPAddr: + ip = addr.IP + } + if ip == nil || (ip.IsLoopback() && !includeLoopback) { + continue + } + + if ipv4 := ip.To4(); ipv4 == nil { + if !IPv6Requested { + continue + } else if !isSupportedIPv6(ip) { + continue + } + } else if !IPv4Requested { + continue + } + + if ipFilter != nil && !ipFilter(ip) { + continue + } + + ips = append(ips, ip) + } + } + return ips, nil +} + +// The conditions of invalidation written below are defined in +// https://tools.ietf.org/html/rfc8445#section-5.1.1.1 +func isSupportedIPv6(ip net.IP) bool { + if len(ip) != net.IPv6len || + isZeros(ip[0:12]) || // !(IPv4-compatible IPv6) + ip[0] == 0xfe && ip[1]&0xc0 == 0xc0 || // !(IPv6 site-local unicast) + ip.IsLinkLocalUnicast() || + ip.IsLinkLocalMulticast() { + return false + } + return true +} + +func isZeros(ip net.IP) bool { + for i := 0; i < len(ip); i++ { + if ip[i] != 0 { + return false + } + } + return true +} + // NewUDPMuxDefault creates an implementation of UDPMux func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { if params.Logger == nil { params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") } + var localAddrsForUnspecified []net.Addr + if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { + params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr()) + } else if ok && addr.IP.IsUnspecified() { + // For unspecified addresses, the correct behavior is to return errListenUnspecified, but + // it will break the applications that are already using unspecified UDP connection + // with UDPMuxDefault, so print a warn log and create a local address list for mux. + params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + var networks []ice.NetworkType + switch { + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + + case addr.IP.To16() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + + default: + params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) + } + if len(networks) > 0 { + if params.Net == nil { + var err error + if params.Net, err = stdnet.NewNet(); err != nil { + params.Logger.Errorf("failed to get create network: %v", err) + } + } + + ips, err := localInterfaces(params.Net, nil, nil, networks, true) + if err == nil { + for _, ip := range ips { + localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port}) + } + } else { + params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err) + } + } + } + return &UDPMuxDefault{ addressMap: map[string][]*udpMuxedConn{}, params: params, @@ -68,7 +199,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return newBufferHolder(receiveMTU + maxAddrSize) }, }, - localAddrsForUnspecified: []net.Addr{}, + localAddrsForUnspecified: localAddrsForUnspecified, } }