/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package netstack2 import ( "bytes" "context" "crypto/rand" "encoding/binary" "errors" "fmt" "io" "net" "net/netip" "os" "regexp" "strconv" "strings" "syscall" "time" "github.com/fosrl/newt/logger" "golang.zx2c4.com/wireguard/tun" "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) type netTun struct { ep *channel.Endpoint proxyEp *channel.Endpoint // Separate endpoint for promiscuous mode stack *stack.Stack events chan tun.Event notifyHandle *channel.NotificationHandle proxyNotifyHandle *channel.NotificationHandle // Notify handle for proxy endpoint incomingPacket chan *buffer.View mtu int dnsServers []netip.Addr hasV4, hasV6 bool tcpHandler *TCPHandler udpHandler *UDPHandler } type Net netTun // NetTunOptions contains options for creating a NetTUN device type NetTunOptions struct { EnableTCPProxy bool EnableUDPProxy bool } // CreateNetTUN creates a new TUN device with netstack without proxying func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{ EnableTCPProxy: true, EnableUDPProxy: true, }) } // CreateNetTUNWithOptions creates a new TUN device with netstack and optional TCP/UDP proxying func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, options NetTunOptions) (tun.Device, *Net, error) { stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, HandleLocal: true, } dev := &netTun{ ep: channel.New(1024, uint32(mtu), ""), proxyEp: channel.New(1024, uint32(mtu), ""), stack: stack.New(stackOpts), events: make(chan tun.Event, 10), incomingPacket: make(chan *buffer.View), dnsServers: dnsServers, mtu: mtu, } if options.EnableTCPProxy { dev.tcpHandler = NewTCPHandler(dev.stack) if err := dev.tcpHandler.InstallTCPHandler(); err != nil { return nil, nil, fmt.Errorf("failed to install TCP handler: %v", err) } } if options.EnableUDPProxy { dev.udpHandler = NewUDPHandler(dev.stack) if err := dev.udpHandler.InstallUDPHandler(); err != nil { return nil, nil, fmt.Errorf("failed to install UDP handler: %v", err) } } sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is enabled by default tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) if tcpipErr != nil { return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } // Create NIC 1 (main interface, no promiscuous mode) dev.notifyHandle = dev.ep.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(1, dev.ep) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) } for _, ip := range localAddresses { var protoNumber tcpip.NetworkProtocolNumber if ip.Is4() { protoNumber = ipv4.ProtocolNumber } else if ip.Is6() { protoNumber = ipv6.ProtocolNumber } protoAddr := tcpip.ProtocolAddress{ Protocol: protoNumber, AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), } tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) if tcpipErr != nil { return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) } if ip.Is4() { dev.hasV4 = true } else if ip.Is6() { dev.hasV6 = true } } if dev.hasV4 { // dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) // add 100.90.129.0/24 proxySubnet := netip.MustParsePrefix("100.90.129.0/24") proxyTcpipSubnet, err := tcpip.NewSubnet( tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()), tcpip.MaskFromBytes(proxySubnet.Addr().AsSlice()), ) if err != nil { return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err) } dev.stack.AddRoute(tcpip.Route{Destination: proxyTcpipSubnet, NIC: 1}) } // if dev.hasV6 { // // dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) // } // Add specific route for proxy network (10.20.20.0/24) to NIC 2 if options.EnableTCPProxy || options.EnableUDPProxy { dev.proxyNotifyHandle = dev.proxyEp.AddNotify(dev) tcpipErr = dev.stack.CreateNIC(2, dev.proxyEp) if tcpipErr != nil { return nil, nil, fmt.Errorf("CreateNIC 2 (proxy): %v", tcpipErr) } // Enable promiscuous mode ONLY on NIC 2 if tcpipErr := dev.stack.SetPromiscuousMode(2, true); tcpipErr != nil { return nil, nil, fmt.Errorf("SetPromiscuousMode on NIC 2: %v", tcpipErr) } // Enable spoofing ONLY on NIC 2 if tcpipErr := dev.stack.SetSpoofing(2, true); tcpipErr != nil { return nil, nil, fmt.Errorf("SetSpoofing on NIC 2: %v", tcpipErr) } // Add the proxy network address (10.20.20.1/24) to NIC 2 // This allows the stack to accept connections to any IP in this range when in promiscuous mode // Similar to how tun2socks adds 10.0.0.1/8 for multicast support // The PEB: CanBePrimaryEndpoint is CRITICAL - it allows the stack to build routes // and accept connections to any IP in this range when in promiscuous+spoofing mode proxyAddr := netip.MustParseAddr("10.20.20.1") protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.AddrFromSlice(proxyAddr.AsSlice()), PrefixLen: 24, // /24 network }, } tcpipErr = dev.stack.AddProtocolAddress(2, protoAddr, stack.AddressProperties{ PEB: stack.CanBePrimaryEndpoint, // Allow this to be used as primary endpoint }) if tcpipErr != nil { return nil, nil, fmt.Errorf("AddProtocolAddress for proxy NIC: %v", tcpipErr) } proxySubnet := netip.MustParsePrefix("10.20.20.0/24") proxyTcpipSubnet, err := tcpip.NewSubnet( tcpip.AddrFromSlice(proxySubnet.Addr().AsSlice()), tcpip.MaskFromBytes(net.CIDRMask(24, 32)), ) if err != nil { return nil, nil, fmt.Errorf("failed to create proxy subnet: %v", err) } dev.stack.AddRoute(tcpip.Route{ Destination: proxyTcpipSubnet, NIC: 2, }) } // print the stack routes table and interfaces for debugging logger.Info("Stack configuration:") // Print NICs nics := dev.stack.NICInfo() for nicID, nicInfo := range nics { logger.Info("NIC %d: %s (MTU: %d)", nicID, nicInfo.Name, nicInfo.MTU) for _, addr := range nicInfo.ProtocolAddresses { logger.Info(" Address: %s", addr.AddressWithPrefix) } } // Print routing table routes := dev.stack.GetRouteTable() logger.Info("Routing table (%d routes):", len(routes)) for i, route := range routes { logger.Info(" Route %d: %s -> NIC %d", i, route.Destination, route.NIC) } dev.events <- tun.EventUp return dev, (*Net)(dev), nil } func (tun *netTun) Name() (string, error) { return "go", nil } func (tun *netTun) File() *os.File { return nil } func (tun *netTun) Events() <-chan tun.Event { return tun.events } func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } n, err := view.Read(buf[0][offset:]) if err != nil { return 0, err } sizes[0] = n return 1, nil } func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { for _, buf := range buf { packet := buf[offset:] if len(packet) == 0 { continue } pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) // Determine which NIC to inject the packet into based on destination IP targetEp := tun.ep // Default to NIC 1 switch packet[0] >> 4 { case 4: // Parse IPv4 header to check destination if len(packet) >= header.IPv4MinimumSize { ipv4Header := header.IPv4(packet) dstIP := ipv4Header.DestinationAddress() // Check if destination is in the proxy range (10.20.20.0/24) // If so, inject into proxyEp (NIC 2) which has promiscuous mode if tun.proxyEp != nil { dstBytes := dstIP.As4() // Check for 10.20.20.x if dstBytes[0] == 10 && dstBytes[1] == 20 && dstBytes[2] == 20 { targetEp = tun.proxyEp // Log what protocol this is proto := "unknown" if len(packet) > header.IPv4MinimumSize { switch ipv4Header.Protocol() { case uint8(header.TCPProtocolNumber): proto = "TCP" case uint8(header.UDPProtocolNumber): proto = "UDP" case uint8(header.ICMPv4ProtocolNumber): proto = "ICMP" } } logger.Info("Routing %s packet to NIC 2 (proxy): dst=%s", proto, dstIP) } } } targetEp.InjectInbound(header.IPv4ProtocolNumber, pkb) case 6: // For IPv6, always use NIC 1 for now targetEp.InjectInbound(header.IPv6ProtocolNumber, pkb) default: return 0, syscall.EAFNOSUPPORT } } return len(buf), nil } // logPacketDetails parses and logs packet information func logPacketDetails(pkt *stack.PacketBuffer, nicID int) { netProto := pkt.NetworkProtocolNumber var srcIP, dstIP string var protocol string var srcPort, dstPort uint16 // Parse network layer switch netProto { case header.IPv4ProtocolNumber: if pkt.NetworkHeader().View().Size() >= header.IPv4MinimumSize { ipv4 := header.IPv4(pkt.NetworkHeader().Slice()) srcIP = ipv4.SourceAddress().String() dstIP = ipv4.DestinationAddress().String() // Parse transport layer switch ipv4.Protocol() { case uint8(header.TCPProtocolNumber): protocol = "TCP" if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize { tcp := header.TCP(pkt.TransportHeader().Slice()) srcPort = tcp.SourcePort() dstPort = tcp.DestinationPort() } case uint8(header.UDPProtocolNumber): protocol = "UDP" if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize { udp := header.UDP(pkt.TransportHeader().Slice()) srcPort = udp.SourcePort() dstPort = udp.DestinationPort() } case uint8(header.ICMPv4ProtocolNumber): protocol = "ICMPv4" default: protocol = fmt.Sprintf("Proto-%d", ipv4.Protocol()) } } case header.IPv6ProtocolNumber: if pkt.NetworkHeader().View().Size() >= header.IPv6MinimumSize { ipv6 := header.IPv6(pkt.NetworkHeader().Slice()) srcIP = ipv6.SourceAddress().String() dstIP = ipv6.DestinationAddress().String() // Parse transport layer switch ipv6.TransportProtocol() { case header.TCPProtocolNumber: protocol = "TCP" if pkt.TransportHeader().View().Size() >= header.TCPMinimumSize { tcp := header.TCP(pkt.TransportHeader().Slice()) srcPort = tcp.SourcePort() dstPort = tcp.DestinationPort() } case header.UDPProtocolNumber: protocol = "UDP" if pkt.TransportHeader().View().Size() >= header.UDPMinimumSize { udp := header.UDP(pkt.TransportHeader().Slice()) srcPort = udp.SourcePort() dstPort = udp.DestinationPort() } case header.ICMPv6ProtocolNumber: protocol = "ICMPv6" default: protocol = fmt.Sprintf("Proto-%d", ipv6.TransportProtocol()) } } default: protocol = fmt.Sprintf("Unknown-NetProto-%d", netProto) } packetSize := pkt.Size() if srcPort > 0 && dstPort > 0 { logger.Info("NIC %d packet: %s %s:%d -> %s:%d (size: %d bytes)", nicID, protocol, srcIP, srcPort, dstIP, dstPort, packetSize) } else { logger.Info("NIC %d packet: %s %s -> %s (size: %d bytes)", nicID, protocol, srcIP, dstIP, packetSize) } } func (tun *netTun) WriteNotify() { // Handle notifications from main endpoint (NIC 1) pkt := tun.ep.Read() if pkt != nil { view := pkt.ToView() pkt.DecRef() tun.incomingPacket <- view return } // Handle notifications from proxy endpoint (NIC 2) if it exists if tun.proxyEp != nil { pkt = tun.proxyEp.Read() if pkt != nil { view := pkt.ToView() pkt.DecRef() tun.incomingPacket <- view } } } func (tun *netTun) Close() error { tun.stack.RemoveNIC(1) // Clean up proxy NIC if it exists if tun.proxyEp != nil { tun.stack.RemoveNIC(2) tun.proxyEp.RemoveNotify(tun.proxyNotifyHandle) tun.proxyEp.Close() } tun.stack.Close() tun.ep.RemoveNotify(tun.notifyHandle) tun.ep.Close() if tun.events != nil { close(tun.events) } if tun.incomingPacket != nil { close(tun.incomingPacket) } return nil } func (tun *netTun) MTU() (int, error) { return tun.mtu, nil } func (tun *netTun) BatchSize() int { return 1 } func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { var protoNumber tcpip.NetworkProtocolNumber if endpoint.Addr().Is4() { protoNumber = ipv4.ProtocolNumber } else { protoNumber = ipv6.ProtocolNumber } return tcpip.FullAddress{ NIC: 1, Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), Port: endpoint.Port(), }, protoNumber } func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { fa, pn := convertToFullAddr(addr) return gonet.DialContextTCP(ctx, net.stack, fa, pn) } func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { fa, pn := convertToFullAddr(addr) return gonet.DialTCP(net.stack, fa, pn) } func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { if addr == nil { return net.DialTCPAddrPort(netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { fa, pn := convertToFullAddr(addr) return gonet.ListenTCP(net.stack, fa, pn) } func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { if addr == nil { return net.ListenTCPAddrPort(netip.AddrPort{}) } ip, _ := netip.AddrFromSlice(addr.IP) return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) } func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { var lfa, rfa *tcpip.FullAddress var pn tcpip.NetworkProtocolNumber if laddr.IsValid() || laddr.Port() > 0 { var addr tcpip.FullAddress addr, pn = convertToFullAddr(laddr) lfa = &addr } if raddr.IsValid() || raddr.Port() > 0 { var addr tcpip.FullAddress addr, pn = convertToFullAddr(raddr) rfa = &addr } return gonet.DialUDP(net.stack, lfa, rfa, pn) } func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { return net.DialUDPAddrPort(laddr, netip.AddrPort{}) } func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { var la, ra netip.AddrPort if laddr != nil { ip, _ := netip.AddrFromSlice(laddr.IP) la = netip.AddrPortFrom(ip, uint16(laddr.Port)) } if raddr != nil { ip, _ := netip.AddrFromSlice(raddr.IP) ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) } return net.DialUDPAddrPort(la, ra) } func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { return net.DialUDP(laddr, nil) } type PingConn struct { laddr PingAddr raddr PingAddr wq waiter.Queue ep tcpip.Endpoint deadline *time.Timer } type PingAddr struct{ addr netip.Addr } func (ia PingAddr) String() string { return ia.addr.String() } func (ia PingAddr) Network() string { if ia.addr.Is4() { return "ping4" } else if ia.addr.Is6() { return "ping6" } return "ping" } func (ia PingAddr) Addr() netip.Addr { return ia.addr } func PingAddrFromAddr(addr netip.Addr) *PingAddr { return &PingAddr{addr} } func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { if !laddr.IsValid() && !raddr.IsValid() { return nil, errors.New("ping dial: invalid address") } v6 := laddr.Is6() || raddr.Is6() bind := laddr.IsValid() if !bind { if v6 { laddr = netip.IPv6Unspecified() } else { laddr = netip.IPv4Unspecified() } } tn := icmp.ProtocolNumber4 pn := ipv4.ProtocolNumber if v6 { tn = icmp.ProtocolNumber6 pn = ipv6.ProtocolNumber } pc := &PingConn{ laddr: PingAddr{laddr}, deadline: time.NewTimer(time.Hour << 10), } pc.deadline.Stop() ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) if tcpipErr != nil { return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) } pc.ep = ep if bind { fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { return nil, fmt.Errorf("ping bind: %s", tcpipErr) } } if raddr.IsValid() { pc.raddr = PingAddr{raddr} fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { return nil, fmt.Errorf("ping connect: %s", tcpipErr) } } return pc, nil } func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { return net.DialPingAddr(laddr, netip.Addr{}) } func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { var la, ra netip.Addr if laddr != nil { la = laddr.addr } if raddr != nil { ra = raddr.addr } return net.DialPingAddr(la, ra) } func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { var la netip.Addr if laddr != nil { la = laddr.addr } return net.ListenPingAddr(la) } func (pc *PingConn) LocalAddr() net.Addr { return pc.laddr } func (pc *PingConn) RemoteAddr() net.Addr { return pc.raddr } func (pc *PingConn) Close() error { pc.deadline.Reset(0) pc.ep.Close() return nil } func (pc *PingConn) SetWriteDeadline(t time.Time) error { return errors.New("not implemented") } func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { var na netip.Addr switch v := addr.(type) { case *PingAddr: na = v.addr case *net.IPAddr: na, _ = netip.AddrFromSlice(v.IP) default: return 0, fmt.Errorf("ping write: wrong net.Addr type") } if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { return 0, fmt.Errorf("ping write: mismatched protocols") } buf := bytes.NewReader(p) rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) // won't block, no deadlines n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ To: &rfa, }) if tcpipErr != nil { return int(n64), fmt.Errorf("ping write: %s", tcpipErr) } return int(n64), nil } func (pc *PingConn) Write(p []byte) (n int, err error) { return pc.WriteTo(p, &pc.raddr) } func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) pc.wq.EventRegister(&e) defer pc.wq.EventUnregister(&e) select { case <-pc.deadline.C: return 0, nil, os.ErrDeadlineExceeded case <-notifyCh: } w := tcpip.SliceWriter(p) res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ NeedRemoteAddr: true, }) if tcpipErr != nil { return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) } remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) return res.Count, &PingAddr{remoteAddr}, nil } func (pc *PingConn) Read(p []byte) (n int, err error) { n, _, err = pc.ReadFrom(p) return } func (pc *PingConn) SetDeadline(t time.Time) error { // pc.SetWriteDeadline is unimplemented return pc.SetReadDeadline(t) } func (pc *PingConn) SetReadDeadline(t time.Time) error { pc.deadline.Reset(time.Until(t)) return nil } var ( errNoSuchHost = errors.New("no such host") errLameReferral = errors.New("lame referral") errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") errServerMisbehaving = errors.New("server misbehaving") errInvalidDNSResponse = errors.New("invalid DNS response") errNoAnswerFromDNSServer = errors.New("no answer from DNS server") errServerTemporarilyMisbehaving = errors.New("server misbehaving") errCanceled = errors.New("operation was canceled") errTimeout = errors.New("i/o timeout") errNumericPort = errors.New("port must be numeric") errNoSuitableAddress = errors.New("no suitable address found") errMissingAddress = errors.New("missing address") ) func (net *Net) LookupHost(host string) (addrs []string, err error) { return net.LookupContextHost(context.Background(), host) } func isDomainName(s string) bool { l := len(s) if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { return false } last := byte('.') nonNumeric := false partlen := 0 for i := 0; i < len(s); i++ { c := s[i] switch { default: return false case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': nonNumeric = true partlen++ case '0' <= c && c <= '9': partlen++ case c == '-': if last == '.' { return false } partlen++ nonNumeric = true case c == '.': if last == '.' || last == '-' { return false } if partlen > 63 || partlen == 0 { return false } partlen = 0 } last = c } if last == '-' || partlen > 63 { return false } return nonNumeric } func randU16() uint16 { var b [2]byte _, err := rand.Read(b[:]) if err != nil { panic(err) } return binary.LittleEndian.Uint16(b[:]) } func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { id = randU16() b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) b.EnableCompression() if err := b.StartQuestions(); err != nil { return 0, nil, nil, err } if err := b.Question(q); err != nil { return 0, nil, nil, err } tcpReq, err = b.Finish() udpReq = tcpReq[2:] l := len(tcpReq) - 2 tcpReq[0] = byte(l >> 8) tcpReq[1] = byte(l) return id, udpReq, tcpReq, err } func equalASCIIName(x, y dnsmessage.Name) bool { if x.Length != y.Length { return false } for i := 0; i < int(x.Length); i++ { a := x.Data[i] b := y.Data[i] if 'A' <= a && a <= 'Z' { a += 0x20 } if 'A' <= b && b <= 'Z' { b += 0x20 } if a != b { return false } } return true } func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { if !respHdr.Response { return false } if reqID != respHdr.ID { return false } if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { return false } return true } func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { if _, err := c.Write(b); err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } b = make([]byte, 512) for { n, err := c.Read(b) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } var p dnsmessage.Parser h, err := p.Start(b[:n]) if err != nil { continue } q, err := p.Question() if err != nil || !checkResponse(id, query, h, q) { continue } return p, h, nil } } func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { if _, err := c.Write(b); err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } b = make([]byte, 1280) if _, err := io.ReadFull(c, b[:2]); err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } l := int(b[0])<<8 | int(b[1]) if l > len(b) { b = make([]byte, l) } n, err := io.ReadFull(c, b[:l]) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } var p dnsmessage.Parser h, err := p.Start(b[:n]) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage } q, err := p.Question() if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage } if !checkResponse(id, query, h, q) { return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse } return p, h, nil } func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { q.Class = dnsmessage.ClassINET id, udpReq, tcpReq, err := newRequest(q) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage } for _, useUDP := range []bool{true, false} { ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) defer cancel() var c net.Conn var err error if useUDP { c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) } else { c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) } if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } if d, ok := ctx.Deadline(); ok && !d.IsZero() { err := c.SetDeadline(d) if err != nil { return dnsmessage.Parser{}, dnsmessage.Header{}, err } } var p dnsmessage.Parser var h dnsmessage.Header if useUDP { p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) } else { p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) } c.Close() if err != nil { if err == context.Canceled { err = errCanceled } else if err == context.DeadlineExceeded { err = errTimeout } return dnsmessage.Parser{}, dnsmessage.Header{}, err } if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse } if h.Truncated { continue } return p, h, nil } return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer } func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { if h.RCode == dnsmessage.RCodeNameError { return errNoSuchHost } _, err := p.AnswerHeader() if err != nil && err != dnsmessage.ErrSectionDone { return errCannotUnmarshalDNSMessage } if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { return errLameReferral } if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { if h.RCode == dnsmessage.RCodeServerFailure { return errServerTemporarilyMisbehaving } return errServerMisbehaving } return nil } func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { for { h, err := p.AnswerHeader() if err == dnsmessage.ErrSectionDone { return errNoSuchHost } if err != nil { return errCannotUnmarshalDNSMessage } if h.Type == qtype { return nil } if err := p.SkipAnswer(); err != nil { return errCannotUnmarshalDNSMessage } } } func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { var lastErr error n, err := dnsmessage.NewName(name) if err != nil { return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage } q := dnsmessage.Question{ Name: n, Type: qtype, Class: dnsmessage.ClassINET, } for i := 0; i < 2; i++ { for _, server := range tnet.dnsServers { p, h, err := tnet.exchange(ctx, server, q, time.Second*5) if err != nil { dnsErr := &net.DNSError{ Err: err.Error(), Name: name, Server: server.String(), } if nerr, ok := err.(net.Error); ok && nerr.Timeout() { dnsErr.IsTimeout = true } if _, ok := err.(*net.OpError); ok { dnsErr.IsTemporary = true } lastErr = dnsErr continue } if err := checkHeader(&p, h); err != nil { dnsErr := &net.DNSError{ Err: err.Error(), Name: name, Server: server.String(), } if err == errServerTemporarilyMisbehaving { dnsErr.IsTemporary = true } if err == errNoSuchHost { dnsErr.IsNotFound = true return p, server.String(), dnsErr } lastErr = dnsErr continue } err = skipToAnswer(&p, qtype) if err == nil { return p, server.String(), nil } lastErr = &net.DNSError{ Err: err.Error(), Name: name, Server: server.String(), } if err == errNoSuchHost { lastErr.(*net.DNSError).IsNotFound = true return p, server.String(), lastErr } } } return dnsmessage.Parser{}, "", lastErr } func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { if host == "" || (!tnet.hasV6 && !tnet.hasV4) { return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} } zlen := len(host) if strings.IndexByte(host, ':') != -1 { if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { zlen = zidx } } if ip, err := netip.ParseAddr(host[:zlen]); err == nil { return []string{ip.String()}, nil } if !isDomainName(host) { return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} } type result struct { p dnsmessage.Parser server string error } var addrsV4, addrsV6 []netip.Addr lanes := 0 if tnet.hasV4 { lanes++ } if tnet.hasV6 { lanes++ } lane := make(chan result, lanes) var lastErr error if tnet.hasV4 { go func() { p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) lane <- result{p, server, err} }() } if tnet.hasV6 { go func() { p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) lane <- result{p, server, err} }() } for l := 0; l < lanes; l++ { result := <-lane if result.error != nil { if lastErr == nil { lastErr = result.error } continue } loop: for { h, err := result.p.AnswerHeader() if err != nil && err != dnsmessage.ErrSectionDone { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } } if err != nil { break } switch h.Type { case dnsmessage.TypeA: a, err := result.p.AResource() if err != nil { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } break loop } addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) case dnsmessage.TypeAAAA: aaaa, err := result.p.AAAAResource() if err != nil { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } break loop } addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) default: if err := result.p.SkipAnswer(); err != nil { lastErr = &net.DNSError{ Err: errCannotMarshalDNSMessage.Error(), Name: host, Server: result.server, } break loop } continue } } } // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled var addrs []netip.Addr if tnet.hasV6 { addrs = append(addrsV6, addrsV4...) } else { addrs = append(addrsV4, addrsV6...) } if len(addrs) == 0 && lastErr != nil { return nil, lastErr } saddrs := make([]string, 0, len(addrs)) for _, ip := range addrs { saddrs = append(saddrs, ip.String()) } return saddrs, nil } func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { if deadline.IsZero() { return deadline, nil } timeRemaining := deadline.Sub(now) if timeRemaining <= 0 { return time.Time{}, errTimeout } timeout := timeRemaining / time.Duration(addrsRemaining) const saneMinimum = 2 * time.Second if timeout < saneMinimum { if timeRemaining < saneMinimum { timeout = timeRemaining } else { timeout = saneMinimum } } return now.Add(timeout), nil } var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { if ctx == nil { panic("nil context") } var acceptV4, acceptV6 bool matches := protoSplitter.FindStringSubmatch(network) if matches == nil { return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} } else if len(matches[2]) == 0 { acceptV4 = true acceptV6 = true } else { acceptV4 = matches[2][0] == '4' acceptV6 = !acceptV4 } var host string var port int if matches[1] == "ping" { host = address } else { var sport string var err error host, sport, err = net.SplitHostPort(address) if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } port, err = strconv.Atoi(sport) if err != nil || port < 0 || port > 65535 { return nil, &net.OpError{Op: "dial", Err: errNumericPort} } } allAddr, err := tnet.LookupContextHost(ctx, host) if err != nil { return nil, &net.OpError{Op: "dial", Err: err} } var addrs []netip.AddrPort for _, addr := range allAddr { ip, err := netip.ParseAddr(addr) if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) } } if len(addrs) == 0 && len(allAddr) != 0 { return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} } var firstErr error for i, addr := range addrs { select { case <-ctx.Done(): err := ctx.Err() if err == context.Canceled { err = errCanceled } else if err == context.DeadlineExceeded { err = errTimeout } return nil, &net.OpError{Op: "dial", Err: err} default: } dialCtx := ctx if deadline, hasDeadline := ctx.Deadline(); hasDeadline { partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) if err != nil { if firstErr == nil { firstErr = &net.OpError{Op: "dial", Err: err} } break } if partialDeadline.Before(deadline) { var cancel context.CancelFunc dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) defer cancel() } } var c net.Conn switch matches[1] { case "tcp": c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) case "udp": c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) case "ping": c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) } if err == nil { return c, nil } if firstErr == nil { firstErr = err } } if firstErr == nil { firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} } return nil, firstErr } func (tnet *Net) Dial(network, address string) (net.Conn, error) { return tnet.DialContext(context.Background(), network, address) }