mirror of
https://github.com/fosrl/newt.git
synced 2026-02-07 21:46:39 +00:00
1308 lines
34 KiB
Go
1308 lines
34 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package netstack2
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/fosrl/newt/logger"
|
|
"golang.zx2c4.com/wireguard/tun"
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
"gvisor.dev/gvisor/pkg/buffer"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
"gvisor.dev/gvisor/pkg/waiter"
|
|
)
|
|
|
|
type netTun struct {
|
|
ep *channel.Endpoint
|
|
proxyEp *channel.Endpoint // Separate endpoint for promiscuous mode
|
|
stack *stack.Stack
|
|
events chan tun.Event
|
|
notifyHandle *channel.NotificationHandle
|
|
proxyNotifyHandle *channel.NotificationHandle // Notify handle for proxy endpoint
|
|
incomingPacket chan *buffer.View
|
|
mtu int
|
|
dnsServers []netip.Addr
|
|
hasV4, hasV6 bool
|
|
tcpHandler *TCPHandler
|
|
udpHandler *UDPHandler
|
|
}
|
|
|
|
// Net is the networking view of a netTun: it shares the same underlying data
// and carries the dial, listen, ping and DNS lookup methods that operate on
// the device's network stack.
type Net netTun
|
|
|
|
// NetTunOptions selects the optional features CreateNetTUNWithOptions sets up
// when it builds a NetTUN device.
type NetTunOptions struct {
	// EnableTCPProxy installs the TCP handler on the stack. Together with
	// EnableUDPProxy it also decides whether the proxy NIC is created.
	EnableTCPProxy bool

	// EnableUDPProxy installs the UDP handler on the stack. Together with
	// EnableTCPProxy it also decides whether the proxy NIC is created.
	EnableUDPProxy bool
}
|
|
|
|
// CreateNetTUN creates a new TUN device with netstack without proxying
|
|
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
|
|
return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{
|
|
EnableTCPProxy: true,
|
|
EnableUDPProxy: true,
|
|
})
|
|
}
|
|
|
|
// CreateNetTUNWithOptions creates a new TUN device with netstack and optional TCP/UDP proxying
|
|
func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, options NetTunOptions) (tun.Device, *Net, error) {
|
|
stackOpts := stack.Options{
|
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
|
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
|
|
HandleLocal: true,
|
|
}
|
|
dev := &netTun{
|
|
ep: channel.New(1024, uint32(mtu), ""),
|
|
proxyEp: channel.New(1024, uint32(mtu), ""),
|
|
stack: stack.New(stackOpts),
|
|
events: make(chan tun.Event, 10),
|
|
incomingPacket: make(chan *buffer.View),
|
|
dnsServers: dnsServers,
|
|
mtu: mtu,
|
|
}
|
|
|
|
if options.EnableTCPProxy {
|
|
dev.tcpHandler = NewTCPHandler(dev.stack)
|
|
if err := dev.tcpHandler.InstallTCPHandler(); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to install TCP handler: %v", err)
|
|
}
|
|
}
|
|
|
|
if options.EnableUDPProxy {
|
|
dev.udpHandler = NewUDPHandler(dev.stack)
|
|
if err := dev.udpHandler.InstallUDPHandler(); err != nil {
|
|
return nil, nil, fmt.Errorf("failed to install UDP handler: %v", err)
|
|
}
|
|
}
|
|
|
|
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default
|
|
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
|
|
}
|
|
// Create NIC 1 (main interface, no promiscuous mode)
|
|
dev.notifyHandle = dev.ep.AddNotify(dev)
|
|
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
|
|
}
|
|
|
|
for _, ip := range localAddresses {
|
|
var protoNumber tcpip.NetworkProtocolNumber
|
|
if ip.Is4() {
|
|
protoNumber = ipv4.ProtocolNumber
|
|
} else if ip.Is6() {
|
|
protoNumber = ipv6.ProtocolNumber
|
|
}
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
Protocol: protoNumber,
|
|
AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
|
|
}
|
|
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
|
|
}
|
|
if ip.Is4() {
|
|
dev.hasV4 = true
|
|
} else if ip.Is6() {
|
|
dev.hasV6 = true
|
|
}
|
|
}
|
|
if dev.hasV4 {
|
|
// dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
|
|
// add 100.90.129.0/24
|
|
proxySubnet := netip.MustParsePrefix("100.90.129.0/24")
|
|
proxyTcpipSubnet, err := tcpip.NewSubnet(
|
|
tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()),
|
|
tcpip.MaskFromBytes(proxySubnet.Addr().AsSlice()),
|
|
)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err)
|
|
}
|
|
dev.stack.AddRoute(tcpip.Route{Destination: proxyTcpipSubnet, NIC: 1})
|
|
}
|
|
// if dev.hasV6 {
|
|
// // dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
|
|
// }
|
|
|
|
// Add specific route for proxy network (10.20.20.0/24) to NIC 2
|
|
if options.EnableTCPProxy || options.EnableUDPProxy {
|
|
dev.proxyNotifyHandle = dev.proxyEp.AddNotify(dev)
|
|
tcpipErr = dev.stack.CreateNIC(2, dev.proxyEp)
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("CreateNIC 2 (proxy): %v", tcpipErr)
|
|
}
|
|
|
|
// Enable promiscuous mode ONLY on NIC 2
|
|
if tcpipErr := dev.stack.SetPromiscuousMode(2, true); tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("SetPromiscuousMode on NIC 2: %v", tcpipErr)
|
|
}
|
|
|
|
// Enable spoofing ONLY on NIC 2
|
|
if tcpipErr := dev.stack.SetSpoofing(2, true); tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("SetSpoofing on NIC 2: %v", tcpipErr)
|
|
}
|
|
|
|
// Add the proxy network address (10.20.20.1/24) to NIC 2
|
|
// This allows the stack to accept connections to any IP in this range when in promiscuous mode
|
|
// Similar to how tun2socks adds 10.0.0.1/8 for multicast support
|
|
// The PEB: CanBePrimaryEndpoint is CRITICAL - it allows the stack to build routes
|
|
// and accept connections to any IP in this range when in promiscuous+spoofing mode
|
|
proxyAddr := netip.MustParseAddr("10.20.20.1")
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
Protocol: ipv4.ProtocolNumber,
|
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
|
Address: tcpip.AddrFromSlice(proxyAddr.AsSlice()),
|
|
PrefixLen: 24, // /24 network
|
|
},
|
|
}
|
|
tcpipErr = dev.stack.AddProtocolAddress(2, protoAddr, stack.AddressProperties{
|
|
PEB: stack.CanBePrimaryEndpoint, // Allow this to be used as primary endpoint
|
|
})
|
|
if tcpipErr != nil {
|
|
return nil, nil, fmt.Errorf("AddProtocolAddress for proxy NIC: %v", tcpipErr)
|
|
}
|
|
|
|
proxySubnet := netip.MustParsePrefix("10.20.20.0/24")
|
|
proxyTcpipSubnet, err := tcpip.NewSubnet(
|
|
tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()),
|
|
tcpip.MaskFromBytes(net.CIDRMask(24, 32)),
|
|
)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err)
|
|
}
|
|
|
|
dev.stack.AddRoute(tcpip.Route{
|
|
Destination: proxyTcpipSubnet,
|
|
NIC: 2,
|
|
})
|
|
}
|
|
|
|
// print the stack routes table and interfaces for debugging
|
|
logger.Info("Stack configuration:")
|
|
|
|
// Print NICs
|
|
nics := dev.stack.NICInfo()
|
|
for nicID, nicInfo := range nics {
|
|
logger.Info("NIC %d: %s (MTU: %d)", nicID, nicInfo.Name, nicInfo.MTU)
|
|
for _, addr := range nicInfo.ProtocolAddresses {
|
|
logger.Info(" Address: %s", addr.AddressWithPrefix)
|
|
}
|
|
}
|
|
|
|
// Print routing table
|
|
routes := dev.stack.GetRouteTable()
|
|
logger.Info("Routing table (%d routes):", len(routes))
|
|
for i, route := range routes {
|
|
logger.Info(" Route %d: %s -> NIC %d", i, route.Destination, route.NIC)
|
|
}
|
|
|
|
dev.events <- tun.EventUp
|
|
return dev, (*Net)(dev), nil
|
|
}
|
|
|
|
func (tun *netTun) Name() (string, error) {
|
|
return "go", nil
|
|
}
|
|
|
|
func (tun *netTun) File() *os.File {
|
|
return nil
|
|
}
|
|
|
|
func (tun *netTun) Events() <-chan tun.Event {
|
|
return tun.events
|
|
}
|
|
|
|
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
|
|
view, ok := <-tun.incomingPacket
|
|
if !ok {
|
|
return 0, os.ErrClosed
|
|
}
|
|
|
|
n, err := view.Read(buf[0][offset:])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
sizes[0] = n
|
|
return 1, nil
|
|
}
|
|
|
|
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
|
|
for _, buf := range buf {
|
|
packet := buf[offset:]
|
|
if len(packet) == 0 {
|
|
continue
|
|
}
|
|
|
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
|
|
|
|
// Determine which NIC to inject the packet into based on destination IP
|
|
targetEp := tun.ep // Default to NIC 1
|
|
|
|
switch packet[0] >> 4 {
|
|
case 4:
|
|
// Parse IPv4 header to check destination
|
|
if len(packet) >= header.IPv4MinimumSize {
|
|
ipv4Header := header.IPv4(packet)
|
|
dstIP := ipv4Header.DestinationAddress()
|
|
|
|
// Check if destination is in the proxy range (10.20.20.0/24)
|
|
// If so, inject into proxyEp (NIC 2) which has promiscuous mode
|
|
if tun.proxyEp != nil {
|
|
dstBytes := dstIP.As4()
|
|
// Check for 10.20.20.x
|
|
if dstBytes[0] == 10 && dstBytes[1] == 20 && dstBytes[2] == 20 {
|
|
targetEp = tun.proxyEp
|
|
// Log what protocol this is
|
|
proto := "unknown"
|
|
if len(packet) > header.IPv4MinimumSize {
|
|
switch ipv4Header.Protocol() {
|
|
case uint8(header.TCPProtocolNumber):
|
|
proto = "TCP"
|
|
case uint8(header.UDPProtocolNumber):
|
|
proto = "UDP"
|
|
case uint8(header.ICMPv4ProtocolNumber):
|
|
proto = "ICMP"
|
|
}
|
|
}
|
|
logger.Info("Routing %s packet to NIC 2 (proxy): dst=%s", proto, dstIP)
|
|
}
|
|
}
|
|
}
|
|
targetEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
|
|
case 6:
|
|
// For IPv6, always use NIC 1 for now
|
|
targetEp.InjectInbound(header.IPv6ProtocolNumber, pkb)
|
|
default:
|
|
return 0, syscall.EAFNOSUPPORT
|
|
}
|
|
}
|
|
return len(buf), nil
|
|
}
|
|
|
|
// logPacketDetails parses and logs packet information
|
|
func logPacketDetails(pkt *stack.PacketBuffer, nicID int) {
|
|
netProto := pkt.NetworkProtocolNumber
|
|
var srcIP, dstIP string
|
|
var protocol string
|
|
var srcPort, dstPort uint16
|
|
|
|
// Parse network layer
|
|
switch netProto {
|
|
case header.IPv4ProtocolNumber:
|
|
if pkt.NetworkHeader().View().Size() >= header.IPv4MinimumSize {
|
|
ipv4 := header.IPv4(pkt.NetworkHeader().Slice())
|
|
srcIP = ipv4.SourceAddress().String()
|
|
dstIP = ipv4.DestinationAddress().String()
|
|
|
|
// Parse transport layer
|
|
switch ipv4.Protocol() {
|
|
case uint8(header.TCPProtocolNumber):
|
|
protocol = "TCP"
|
|
if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize {
|
|
tcp := header.TCP(pkt.TransportHeader().Slice())
|
|
srcPort = tcp.SourcePort()
|
|
dstPort = tcp.DestinationPort()
|
|
}
|
|
case uint8(header.UDPProtocolNumber):
|
|
protocol = "UDP"
|
|
if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize {
|
|
udp := header.UDP(pkt.TransportHeader().Slice())
|
|
srcPort = udp.SourcePort()
|
|
dstPort = udp.DestinationPort()
|
|
}
|
|
case uint8(header.ICMPv4ProtocolNumber):
|
|
protocol = "ICMPv4"
|
|
default:
|
|
protocol = fmt.Sprintf("Proto-%d", ipv4.Protocol())
|
|
}
|
|
}
|
|
case header.IPv6ProtocolNumber:
|
|
if pkt.NetworkHeader().View().Size() >= header.IPv6MinimumSize {
|
|
ipv6 := header.IPv6(pkt.NetworkHeader().Slice())
|
|
srcIP = ipv6.SourceAddress().String()
|
|
dstIP = ipv6.DestinationAddress().String()
|
|
|
|
// Parse transport layer
|
|
switch ipv6.TransportProtocol() {
|
|
case header.TCPProtocolNumber:
|
|
protocol = "TCP"
|
|
if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize {
|
|
tcp := header.TCP(pkt.TransportHeader().Slice())
|
|
srcPort = tcp.SourcePort()
|
|
dstPort = tcp.DestinationPort()
|
|
}
|
|
case header.UDPProtocolNumber:
|
|
protocol = "UDP"
|
|
if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize {
|
|
udp := header.UDP(pkt.TransportHeader().Slice())
|
|
srcPort = udp.SourcePort()
|
|
dstPort = udp.DestinationPort()
|
|
}
|
|
case header.ICMPv6ProtocolNumber:
|
|
protocol = "ICMPv6"
|
|
default:
|
|
protocol = fmt.Sprintf("Proto-%d", ipv6.TransportProtocol())
|
|
}
|
|
}
|
|
default:
|
|
protocol = fmt.Sprintf("Unknown-NetProto-%d", netProto)
|
|
}
|
|
|
|
packetSize := pkt.Size()
|
|
|
|
if srcPort > 0 && dstPort > 0 {
|
|
logger.Info("NIC %d packet: %s %s:%d -> %s:%d (size: %d bytes)",
|
|
nicID, protocol, srcIP, srcPort, dstIP, dstPort, packetSize)
|
|
} else {
|
|
logger.Info("NIC %d packet: %s %s -> %s (size: %d bytes)",
|
|
nicID, protocol, srcIP, dstIP, packetSize)
|
|
}
|
|
}
|
|
|
|
func (tun *netTun) WriteNotify() {
|
|
// Handle notifications from main endpoint (NIC 1)
|
|
pkt := tun.ep.Read()
|
|
if pkt != nil {
|
|
view := pkt.ToView()
|
|
pkt.DecRef()
|
|
tun.incomingPacket <- view
|
|
return
|
|
}
|
|
|
|
// Handle notifications from proxy endpoint (NIC 2) if it exists
|
|
if tun.proxyEp != nil {
|
|
pkt = tun.proxyEp.Read()
|
|
if pkt != nil {
|
|
view := pkt.ToView()
|
|
pkt.DecRef()
|
|
tun.incomingPacket <- view
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tun *netTun) Close() error {
|
|
tun.stack.RemoveNIC(1)
|
|
|
|
// Clean up proxy NIC if it exists
|
|
if tun.proxyEp != nil {
|
|
tun.stack.RemoveNIC(2)
|
|
tun.proxyEp.RemoveNotify(tun.proxyNotifyHandle)
|
|
tun.proxyEp.Close()
|
|
}
|
|
|
|
tun.stack.Close()
|
|
tun.ep.RemoveNotify(tun.notifyHandle)
|
|
tun.ep.Close()
|
|
|
|
if tun.events != nil {
|
|
close(tun.events)
|
|
}
|
|
|
|
if tun.incomingPacket != nil {
|
|
close(tun.incomingPacket)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (tun *netTun) MTU() (int, error) {
|
|
return tun.mtu, nil
|
|
}
|
|
|
|
func (tun *netTun) BatchSize() int {
|
|
return 1
|
|
}
|
|
|
|
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
|
|
var protoNumber tcpip.NetworkProtocolNumber
|
|
if endpoint.Addr().Is4() {
|
|
protoNumber = ipv4.ProtocolNumber
|
|
} else {
|
|
protoNumber = ipv6.ProtocolNumber
|
|
}
|
|
return tcpip.FullAddress{
|
|
NIC: 1,
|
|
Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
|
|
Port: endpoint.Port(),
|
|
}, protoNumber
|
|
}
|
|
|
|
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
|
|
fa, pn := convertToFullAddr(addr)
|
|
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
|
|
}
|
|
|
|
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
|
if addr == nil {
|
|
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
|
|
}
|
|
ip, _ := netip.AddrFromSlice(addr.IP)
|
|
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
|
|
}
|
|
|
|
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
|
|
fa, pn := convertToFullAddr(addr)
|
|
return gonet.DialTCP(net.stack, fa, pn)
|
|
}
|
|
|
|
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
|
|
if addr == nil {
|
|
return net.DialTCPAddrPort(netip.AddrPort{})
|
|
}
|
|
ip, _ := netip.AddrFromSlice(addr.IP)
|
|
return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
|
|
}
|
|
|
|
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
|
|
fa, pn := convertToFullAddr(addr)
|
|
return gonet.ListenTCP(net.stack, fa, pn)
|
|
}
|
|
|
|
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
|
|
if addr == nil {
|
|
return net.ListenTCPAddrPort(netip.AddrPort{})
|
|
}
|
|
ip, _ := netip.AddrFromSlice(addr.IP)
|
|
return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
|
|
}
|
|
|
|
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
|
|
var lfa, rfa *tcpip.FullAddress
|
|
var pn tcpip.NetworkProtocolNumber
|
|
if laddr.IsValid() || laddr.Port() > 0 {
|
|
var addr tcpip.FullAddress
|
|
addr, pn = convertToFullAddr(laddr)
|
|
lfa = &addr
|
|
}
|
|
if raddr.IsValid() || raddr.Port() > 0 {
|
|
var addr tcpip.FullAddress
|
|
addr, pn = convertToFullAddr(raddr)
|
|
rfa = &addr
|
|
}
|
|
return gonet.DialUDP(net.stack, lfa, rfa, pn)
|
|
}
|
|
|
|
func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
|
|
return net.DialUDPAddrPort(laddr, netip.AddrPort{})
|
|
}
|
|
|
|
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
|
var la, ra netip.AddrPort
|
|
if laddr != nil {
|
|
ip, _ := netip.AddrFromSlice(laddr.IP)
|
|
la = netip.AddrPortFrom(ip, uint16(laddr.Port))
|
|
}
|
|
if raddr != nil {
|
|
ip, _ := netip.AddrFromSlice(raddr.IP)
|
|
ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
|
|
}
|
|
return net.DialUDPAddrPort(la, ra)
|
|
}
|
|
|
|
func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
|
|
return net.DialUDP(laddr, nil)
|
|
}
|
|
|
|
type PingConn struct {
|
|
laddr PingAddr
|
|
raddr PingAddr
|
|
wq waiter.Queue
|
|
ep tcpip.Endpoint
|
|
deadline *time.Timer
|
|
}
|
|
|
|
// PingAddr is the address type used by PingConn. It wraps a netip.Addr and
// provides the String and Network methods of a net.Addr.
type PingAddr struct {
	addr netip.Addr
}

// String returns the textual form of the wrapped address.
func (ia PingAddr) String() string {
	return ia.addr.String()
}

// Network names the address family: "ping4" for IPv4, "ping6" for IPv6, and
// plain "ping" when the wrapped address is neither (for example the zero
// value).
func (ia PingAddr) Network() string {
	switch {
	case ia.addr.Is4():
		return "ping4"
	case ia.addr.Is6():
		return "ping6"
	default:
		return "ping"
	}
}

// Addr returns the wrapped netip.Addr.
func (ia PingAddr) Addr() netip.Addr {
	return ia.addr
}

// PingAddrFromAddr wraps addr in a newly allocated PingAddr.
func PingAddrFromAddr(addr netip.Addr) *PingAddr {
	return &PingAddr{addr: addr}
}
|
|
|
|
func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
|
|
if !laddr.IsValid() && !raddr.IsValid() {
|
|
return nil, errors.New("ping dial: invalid address")
|
|
}
|
|
v6 := laddr.Is6() || raddr.Is6()
|
|
bind := laddr.IsValid()
|
|
if !bind {
|
|
if v6 {
|
|
laddr = netip.IPv6Unspecified()
|
|
} else {
|
|
laddr = netip.IPv4Unspecified()
|
|
}
|
|
}
|
|
|
|
tn := icmp.ProtocolNumber4
|
|
pn := ipv4.ProtocolNumber
|
|
if v6 {
|
|
tn = icmp.ProtocolNumber6
|
|
pn = ipv6.ProtocolNumber
|
|
}
|
|
|
|
pc := &PingConn{
|
|
laddr: PingAddr{laddr},
|
|
deadline: time.NewTimer(time.Hour << 10),
|
|
}
|
|
pc.deadline.Stop()
|
|
|
|
ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
|
|
if tcpipErr != nil {
|
|
return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
|
|
}
|
|
pc.ep = ep
|
|
|
|
if bind {
|
|
fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
|
|
if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
|
|
return nil, fmt.Errorf("ping bind: %s", tcpipErr)
|
|
}
|
|
}
|
|
|
|
if raddr.IsValid() {
|
|
pc.raddr = PingAddr{raddr}
|
|
fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
|
|
if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
|
|
return nil, fmt.Errorf("ping connect: %s", tcpipErr)
|
|
}
|
|
}
|
|
|
|
return pc, nil
|
|
}
|
|
|
|
func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
|
|
return net.DialPingAddr(laddr, netip.Addr{})
|
|
}
|
|
|
|
func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
|
|
var la, ra netip.Addr
|
|
if laddr != nil {
|
|
la = laddr.addr
|
|
}
|
|
if raddr != nil {
|
|
ra = raddr.addr
|
|
}
|
|
return net.DialPingAddr(la, ra)
|
|
}
|
|
|
|
func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
|
|
var la netip.Addr
|
|
if laddr != nil {
|
|
la = laddr.addr
|
|
}
|
|
return net.ListenPingAddr(la)
|
|
}
|
|
|
|
func (pc *PingConn) LocalAddr() net.Addr {
|
|
return pc.laddr
|
|
}
|
|
|
|
func (pc *PingConn) RemoteAddr() net.Addr {
|
|
return pc.raddr
|
|
}
|
|
|
|
func (pc *PingConn) Close() error {
|
|
pc.deadline.Reset(0)
|
|
pc.ep.Close()
|
|
return nil
|
|
}
|
|
|
|
func (pc *PingConn) SetWriteDeadline(t time.Time) error {
|
|
return errors.New("not implemented")
|
|
}
|
|
|
|
func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
var na netip.Addr
|
|
switch v := addr.(type) {
|
|
case *PingAddr:
|
|
na = v.addr
|
|
case *net.IPAddr:
|
|
na, _ = netip.AddrFromSlice(v.IP)
|
|
default:
|
|
return 0, fmt.Errorf("ping write: wrong net.Addr type")
|
|
}
|
|
if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
|
|
return 0, fmt.Errorf("ping write: mismatched protocols")
|
|
}
|
|
|
|
buf := bytes.NewReader(p)
|
|
rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
|
|
// won't block, no deadlines
|
|
n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
|
|
To: &rfa,
|
|
})
|
|
if tcpipErr != nil {
|
|
return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
|
|
}
|
|
|
|
return int(n64), nil
|
|
}
|
|
|
|
func (pc *PingConn) Write(p []byte) (n int, err error) {
|
|
return pc.WriteTo(p, &pc.raddr)
|
|
}
|
|
|
|
func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
|
|
pc.wq.EventRegister(&e)
|
|
defer pc.wq.EventUnregister(&e)
|
|
|
|
select {
|
|
case <-pc.deadline.C:
|
|
return 0, nil, os.ErrDeadlineExceeded
|
|
case <-notifyCh:
|
|
}
|
|
|
|
w := tcpip.SliceWriter(p)
|
|
|
|
res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
|
|
NeedRemoteAddr: true,
|
|
})
|
|
if tcpipErr != nil {
|
|
return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
|
|
}
|
|
|
|
remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
|
|
return res.Count, &PingAddr{remoteAddr}, nil
|
|
}
|
|
|
|
func (pc *PingConn) Read(p []byte) (n int, err error) {
|
|
n, _, err = pc.ReadFrom(p)
|
|
return
|
|
}
|
|
|
|
func (pc *PingConn) SetDeadline(t time.Time) error {
|
|
// pc.SetWriteDeadline is unimplemented
|
|
|
|
return pc.SetReadDeadline(t)
|
|
}
|
|
|
|
func (pc *PingConn) SetReadDeadline(t time.Time) error {
|
|
pc.deadline.Reset(time.Until(t))
|
|
return nil
|
|
}
|
|
|
|
// Sentinel errors used by the DNS resolver and address helpers in this file.
var (
	// Lookup outcomes.
	errNoSuchHost            = errors.New("no such host")
	errLameReferral          = errors.New("lame referral")
	errNoAnswerFromDNSServer = errors.New("no answer from DNS server")

	// DNS message encoding, decoding and validation.
	errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
	errCannotMarshalDNSMessage   = errors.New("cannot marshal DNS message")
	errInvalidDNSResponse        = errors.New("invalid DNS response")

	// Server-side failures. Both carry the same text; the "temporarily"
	// variable is a distinct value so callers can tell the cases apart.
	errServerMisbehaving            = errors.New("server misbehaving")
	errServerTemporarilyMisbehaving = errors.New("server misbehaving")

	// Cancellation and timeouts.
	errCanceled = errors.New("operation was canceled")
	errTimeout  = errors.New("i/o timeout")

	// Address and port parsing.
	errNumericPort       = errors.New("port must be numeric")
	errNoSuitableAddress = errors.New("no suitable address found")
	errMissingAddress    = errors.New("missing address")
)
|
|
|
|
func (net *Net) LookupHost(host string) (addrs []string, err error) {
|
|
return net.LookupContextHost(context.Background(), host)
|
|
}
|
|
|
|
// isDomainName reports whether s is an acceptable presentation-format domain
// name for a DNS query.
//
// The name must be non-empty and at most 253 bytes long (254 with a trailing
// dot). It may contain only ASCII letters, digits, '_', '-' and '.'. Every
// label must be 1 to 63 bytes long, a label may not begin with '-', a '-'
// may not be followed directly by '.', and the name may not end in '-'. A
// name made up solely of digits and dots is rejected.
func isDomainName(s string) bool {
	n := len(s)
	switch {
	case n == 0, n > 254:
		return false
	case n == 254 && s[n-1] != '.':
		return false
	}

	prev := byte('.') // pretend a dot precedes the name so the first label is checked like any other
	sawNonNumeric := false
	labelLen := 0
	for i := 0; i < n; i++ {
		ch := s[i]
		switch {
		case ch == '.':
			if prev == '.' || prev == '-' {
				return false
			}
			if labelLen == 0 || labelLen > 63 {
				return false
			}
			labelLen = 0
		case ch == '-':
			if prev == '.' {
				return false
			}
			labelLen++
			sawNonNumeric = true
		case '0' <= ch && ch <= '9':
			labelLen++
		case ch == '_' || ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z'):
			sawNonNumeric = true
			labelLen++
		default:
			return false
		}
		prev = ch
	}

	// The final label has not been length-checked by a closing dot yet.
	if prev == '-' || labelLen > 63 {
		return false
	}
	return sawNonNumeric
}
|
|
|
|
// randU16 returns a 16-bit value read from crypto/rand. It panics if the
// random source reports an error.
func randU16() uint16 {
	var raw [2]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err)
	}
	return binary.LittleEndian.Uint16(raw[:])
}
|
|
|
|
func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
|
|
id = randU16()
|
|
b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
|
|
b.EnableCompression()
|
|
if err := b.StartQuestions(); err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
if err := b.Question(q); err != nil {
|
|
return 0, nil, nil, err
|
|
}
|
|
tcpReq, err = b.Finish()
|
|
udpReq = tcpReq[2:]
|
|
l := len(tcpReq) - 2
|
|
tcpReq[0] = byte(l >> 8)
|
|
tcpReq[1] = byte(l)
|
|
return id, udpReq, tcpReq, err
|
|
}
|
|
|
|
func equalASCIIName(x, y dnsmessage.Name) bool {
|
|
if x.Length != y.Length {
|
|
return false
|
|
}
|
|
for i := 0; i < int(x.Length); i++ {
|
|
a := x.Data[i]
|
|
b := y.Data[i]
|
|
if 'A' <= a && a <= 'Z' {
|
|
a += 0x20
|
|
}
|
|
if 'A' <= b && b <= 'Z' {
|
|
b += 0x20
|
|
}
|
|
if a != b {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
|
|
if !respHdr.Response {
|
|
return false
|
|
}
|
|
if reqID != respHdr.ID {
|
|
return false
|
|
}
|
|
if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
|
|
if _, err := c.Write(b); err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
b = make([]byte, 512)
|
|
for {
|
|
n, err := c.Read(b)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
var p dnsmessage.Parser
|
|
h, err := p.Start(b[:n])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
q, err := p.Question()
|
|
if err != nil || !checkResponse(id, query, h, q) {
|
|
continue
|
|
}
|
|
return p, h, nil
|
|
}
|
|
}
|
|
|
|
func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
|
|
if _, err := c.Write(b); err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
b = make([]byte, 1280)
|
|
if _, err := io.ReadFull(c, b[:2]); err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
l := int(b[0])<<8 | int(b[1])
|
|
if l > len(b) {
|
|
b = make([]byte, l)
|
|
}
|
|
n, err := io.ReadFull(c, b[:l])
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
var p dnsmessage.Parser
|
|
h, err := p.Start(b[:n])
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
|
|
}
|
|
q, err := p.Question()
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
|
|
}
|
|
if !checkResponse(id, query, h, q) {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
|
|
}
|
|
return p, h, nil
|
|
}
|
|
|
|
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
|
|
q.Class = dnsmessage.ClassINET
|
|
id, udpReq, tcpReq, err := newRequest(q)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
|
|
}
|
|
|
|
for _, useUDP := range []bool{true, false} {
|
|
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
|
|
defer cancel()
|
|
|
|
var c net.Conn
|
|
var err error
|
|
if useUDP {
|
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
|
|
} else {
|
|
c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
|
|
}
|
|
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
|
|
err := c.SetDeadline(d)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
}
|
|
var p dnsmessage.Parser
|
|
var h dnsmessage.Header
|
|
if useUDP {
|
|
p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
|
|
} else {
|
|
p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
|
|
}
|
|
c.Close()
|
|
if err != nil {
|
|
if err == context.Canceled {
|
|
err = errCanceled
|
|
} else if err == context.DeadlineExceeded {
|
|
err = errTimeout
|
|
}
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, err
|
|
}
|
|
if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
|
|
}
|
|
if h.Truncated {
|
|
continue
|
|
}
|
|
return p, h, nil
|
|
}
|
|
return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
|
|
}
|
|
|
|
func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
|
|
if h.RCode == dnsmessage.RCodeNameError {
|
|
return errNoSuchHost
|
|
}
|
|
_, err := p.AnswerHeader()
|
|
if err != nil && err != dnsmessage.ErrSectionDone {
|
|
return errCannotUnmarshalDNSMessage
|
|
}
|
|
if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
|
|
return errLameReferral
|
|
}
|
|
if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
|
|
if h.RCode == dnsmessage.RCodeServerFailure {
|
|
return errServerTemporarilyMisbehaving
|
|
}
|
|
return errServerMisbehaving
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
|
|
for {
|
|
h, err := p.AnswerHeader()
|
|
if err == dnsmessage.ErrSectionDone {
|
|
return errNoSuchHost
|
|
}
|
|
if err != nil {
|
|
return errCannotUnmarshalDNSMessage
|
|
}
|
|
if h.Type == qtype {
|
|
return nil
|
|
}
|
|
if err := p.SkipAnswer(); err != nil {
|
|
return errCannotUnmarshalDNSMessage
|
|
}
|
|
}
|
|
}
|
|
|
|
func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
|
|
var lastErr error
|
|
|
|
n, err := dnsmessage.NewName(name)
|
|
if err != nil {
|
|
return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
|
|
}
|
|
q := dnsmessage.Question{
|
|
Name: n,
|
|
Type: qtype,
|
|
Class: dnsmessage.ClassINET,
|
|
}
|
|
|
|
for i := 0; i < 2; i++ {
|
|
for _, server := range tnet.dnsServers {
|
|
p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
|
|
if err != nil {
|
|
dnsErr := &net.DNSError{
|
|
Err: err.Error(),
|
|
Name: name,
|
|
Server: server.String(),
|
|
}
|
|
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
|
|
dnsErr.IsTimeout = true
|
|
}
|
|
if _, ok := err.(*net.OpError); ok {
|
|
dnsErr.IsTemporary = true
|
|
}
|
|
lastErr = dnsErr
|
|
continue
|
|
}
|
|
|
|
if err := checkHeader(&p, h); err != nil {
|
|
dnsErr := &net.DNSError{
|
|
Err: err.Error(),
|
|
Name: name,
|
|
Server: server.String(),
|
|
}
|
|
if err == errServerTemporarilyMisbehaving {
|
|
dnsErr.IsTemporary = true
|
|
}
|
|
if err == errNoSuchHost {
|
|
dnsErr.IsNotFound = true
|
|
return p, server.String(), dnsErr
|
|
}
|
|
lastErr = dnsErr
|
|
continue
|
|
}
|
|
|
|
err = skipToAnswer(&p, qtype)
|
|
if err == nil {
|
|
return p, server.String(), nil
|
|
}
|
|
lastErr = &net.DNSError{
|
|
Err: err.Error(),
|
|
Name: name,
|
|
Server: server.String(),
|
|
}
|
|
if err == errNoSuchHost {
|
|
lastErr.(*net.DNSError).IsNotFound = true
|
|
return p, server.String(), lastErr
|
|
}
|
|
}
|
|
}
|
|
return dnsmessage.Parser{}, "", lastErr
|
|
}
|
|
|
|
func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
|
|
if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
|
|
return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
|
|
}
|
|
zlen := len(host)
|
|
if strings.IndexByte(host, ':') != -1 {
|
|
if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
|
|
zlen = zidx
|
|
}
|
|
}
|
|
if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
|
|
return []string{ip.String()}, nil
|
|
}
|
|
|
|
if !isDomainName(host) {
|
|
return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
|
|
}
|
|
type result struct {
|
|
p dnsmessage.Parser
|
|
server string
|
|
error
|
|
}
|
|
var addrsV4, addrsV6 []netip.Addr
|
|
lanes := 0
|
|
if tnet.hasV4 {
|
|
lanes++
|
|
}
|
|
if tnet.hasV6 {
|
|
lanes++
|
|
}
|
|
lane := make(chan result, lanes)
|
|
var lastErr error
|
|
if tnet.hasV4 {
|
|
go func() {
|
|
p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
|
|
lane <- result{p, server, err}
|
|
}()
|
|
}
|
|
if tnet.hasV6 {
|
|
go func() {
|
|
p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
|
|
lane <- result{p, server, err}
|
|
}()
|
|
}
|
|
for l := 0; l < lanes; l++ {
|
|
result := <-lane
|
|
if result.error != nil {
|
|
if lastErr == nil {
|
|
lastErr = result.error
|
|
}
|
|
continue
|
|
}
|
|
|
|
loop:
|
|
for {
|
|
h, err := result.p.AnswerHeader()
|
|
if err != nil && err != dnsmessage.ErrSectionDone {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
switch h.Type {
|
|
case dnsmessage.TypeA:
|
|
a, err := result.p.AResource()
|
|
if err != nil {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
break loop
|
|
}
|
|
addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
|
|
|
|
case dnsmessage.TypeAAAA:
|
|
aaaa, err := result.p.AAAAResource()
|
|
if err != nil {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
break loop
|
|
}
|
|
addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
|
|
|
|
default:
|
|
if err := result.p.SkipAnswer(); err != nil {
|
|
lastErr = &net.DNSError{
|
|
Err: errCannotMarshalDNSMessage.Error(),
|
|
Name: host,
|
|
Server: result.server,
|
|
}
|
|
break loop
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
// We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
|
|
var addrs []netip.Addr
|
|
if tnet.hasV6 {
|
|
addrs = append(addrsV6, addrsV4...)
|
|
} else {
|
|
addrs = append(addrsV4, addrsV6...)
|
|
}
|
|
|
|
if len(addrs) == 0 && lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
saddrs := make([]string, 0, len(addrs))
|
|
for _, ip := range addrs {
|
|
saddrs = append(saddrs, ip.String())
|
|
}
|
|
return saddrs, nil
|
|
}
|
|
|
|
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
|
|
if deadline.IsZero() {
|
|
return deadline, nil
|
|
}
|
|
timeRemaining := deadline.Sub(now)
|
|
if timeRemaining <= 0 {
|
|
return time.Time{}, errTimeout
|
|
}
|
|
timeout := timeRemaining / time.Duration(addrsRemaining)
|
|
const saneMinimum = 2 * time.Second
|
|
if timeout < saneMinimum {
|
|
if timeRemaining < saneMinimum {
|
|
timeout = timeRemaining
|
|
} else {
|
|
timeout = saneMinimum
|
|
}
|
|
}
|
|
return now.Add(timeout), nil
|
|
}
|
|
|
|
// protoSplitter matches a dial network name. Submatch 1 is the protocol
// ("tcp", "udp" or "ping"); submatch 2 is the optional address-family suffix
// ("4" or "6") and is empty when no suffix is given. Any other network name
// does not match.
var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
|
|
|
|
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
if ctx == nil {
|
|
panic("nil context")
|
|
}
|
|
var acceptV4, acceptV6 bool
|
|
matches := protoSplitter.FindStringSubmatch(network)
|
|
if matches == nil {
|
|
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
|
|
} else if len(matches[2]) == 0 {
|
|
acceptV4 = true
|
|
acceptV6 = true
|
|
} else {
|
|
acceptV4 = matches[2][0] == '4'
|
|
acceptV6 = !acceptV4
|
|
}
|
|
var host string
|
|
var port int
|
|
if matches[1] == "ping" {
|
|
host = address
|
|
} else {
|
|
var sport string
|
|
var err error
|
|
host, sport, err = net.SplitHostPort(address)
|
|
if err != nil {
|
|
return nil, &net.OpError{Op: "dial", Err: err}
|
|
}
|
|
port, err = strconv.Atoi(sport)
|
|
if err != nil || port < 0 || port > 65535 {
|
|
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
|
|
}
|
|
}
|
|
allAddr, err := tnet.LookupContextHost(ctx, host)
|
|
if err != nil {
|
|
return nil, &net.OpError{Op: "dial", Err: err}
|
|
}
|
|
var addrs []netip.AddrPort
|
|
for _, addr := range allAddr {
|
|
ip, err := netip.ParseAddr(addr)
|
|
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
|
|
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
|
|
}
|
|
}
|
|
if len(addrs) == 0 && len(allAddr) != 0 {
|
|
return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
|
|
}
|
|
|
|
var firstErr error
|
|
for i, addr := range addrs {
|
|
select {
|
|
case <-ctx.Done():
|
|
err := ctx.Err()
|
|
if err == context.Canceled {
|
|
err = errCanceled
|
|
} else if err == context.DeadlineExceeded {
|
|
err = errTimeout
|
|
}
|
|
return nil, &net.OpError{Op: "dial", Err: err}
|
|
default:
|
|
}
|
|
|
|
dialCtx := ctx
|
|
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
|
|
partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
|
|
if err != nil {
|
|
if firstErr == nil {
|
|
firstErr = &net.OpError{Op: "dial", Err: err}
|
|
}
|
|
break
|
|
}
|
|
if partialDeadline.Before(deadline) {
|
|
var cancel context.CancelFunc
|
|
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
|
|
defer cancel()
|
|
}
|
|
}
|
|
|
|
var c net.Conn
|
|
switch matches[1] {
|
|
case "tcp":
|
|
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
|
|
case "udp":
|
|
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
|
|
case "ping":
|
|
c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
|
|
}
|
|
if err == nil {
|
|
return c, nil
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
|
|
}
|
|
return nil, firstErr
|
|
}
|
|
|
|
func (tnet *Net) Dial(network, address string) (net.Conn, error) {
|
|
return tnet.DialContext(context.Background(), network, address)
|
|
}
|