package dns import ( "context" "fmt" "net" "net/netip" "sync" "time" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/device" "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) const ( // DNS proxy listening address DNSProxyIP = "10.30.30.30" DNSPort = 53 // Upstream DNS servers UpstreamDNS1 = "8.8.8.8:53" UpstreamDNS2 = "8.8.4.4:53" ) // DNSProxy implements a DNS proxy using gvisor netstack type DNSProxy struct { stack *stack.Stack ep *channel.Endpoint proxyIP netip.Addr mtu int tunDevice tun.Device // Direct reference to underlying TUN device for responses ctx context.Context cancel context.CancelFunc wg sync.WaitGroup } // NewDNSProxy creates a new DNS proxy func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) { proxyIP, err := netip.ParseAddr(DNSProxyIP) if err != nil { return nil, fmt.Errorf("invalid proxy IP: %w", err) } ctx, cancel := context.WithCancel(context.Background()) proxy := &DNSProxy{ proxyIP: proxyIP, mtu: mtu, tunDevice: tunDevice, ctx: ctx, cancel: cancel, } // Create gvisor netstack stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: true, } proxy.ep = channel.New(256, uint32(mtu), "") proxy.stack = stack.New(stackOpts) // Create NIC if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil { return nil, fmt.Errorf("failed to create NIC: %v", err) } // Add IP address protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}).WithPrefix(), } if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { return nil, fmt.Errorf("failed to add protocol address: %v", err) } // Add default route proxy.stack.AddRoute(tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, }) return proxy, nil } // Start starts the DNS proxy and registers with the filter func (p *DNSProxy) Start(device *device.MiddleDevice) error { // Install packet filter rule device.AddRule(p.proxyIP, p.handlePacket) // Start DNS listener p.wg.Add(2) go p.runDNSListener() go p.runPacketSender() logger.Info("DNS proxy started on %s:%d", DNSProxyIP, DNSPort) return nil } // Stop stops the DNS proxy func (p *DNSProxy) Stop(device *device.MiddleDevice) { if device != nil { device.RemoveRule(p.proxyIP) } p.cancel() p.wg.Wait() if p.stack != nil { p.stack.Close() } if p.ep != nil { p.ep.Close() } logger.Info("DNS proxy stopped") } // handlePacket is called by the filter for packets destined to DNS proxy IP func (p *DNSProxy) handlePacket(packet []byte) bool { if len(packet) < 20 { return false // Don't drop, malformed } // Quick check for UDP port 53 proto, ok := util.GetProtocol(packet) if !ok || proto != 17 { // 17 = UDP return false // Not UDP, don't handle } port, ok := util.GetDestPort(packet) if !ok || port != DNSPort { return false // Not DNS port } // Inject packet into our netstack version := packet[0] >> 4 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), }) switch version { case 4: p.ep.InjectInbound(ipv4.ProtocolNumber, pkb) case 6: p.ep.InjectInbound(ipv6.ProtocolNumber, pkb) default: pkb.DecRef() return false } pkb.DecRef() return true // Drop packet from normal path } // runDNSListener listens for DNS queries on the netstack func (p *DNSProxy) runDNSListener() { defer p.wg.Done() // Create UDP listener using gonet laddr := &tcpip.FullAddress{ NIC: 1, Addr: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}), Port: DNSPort, } udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber) if err != nil { logger.Error("Failed to create DNS listener: %v", err) return } defer udpConn.Close() logger.Debug("DNS proxy listening on netstack") // Handle DNS queries buf := make([]byte, 4096) for { select { case <-p.ctx.Done(): return default: } udpConn.SetReadDeadline(time.Now().Add(1 * time.Second)) n, remoteAddr, err := udpConn.ReadFrom(buf) if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { continue } if p.ctx.Err() != nil { return } logger.Error("DNS read error: %v", err) continue } query := make([]byte, n) copy(query, buf[:n]) // Handle query in background go p.forwardDNSQuery(udpConn, query, remoteAddr) } } // forwardDNSQuery forwards a DNS query to upstream DNS server func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) { // Try primary DNS server response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second) if err != nil { // Try secondary DNS server logger.Debug("Primary DNS failed, trying secondary: %v", err) response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second) if err != nil { logger.Error("Both DNS servers failed: %v", err) return } } // Send response back to client through netstack _, err = udpConn.WriteTo(response, clientAddr) if err != nil { logger.Error("Failed to send DNS response: %v", err) } } // queryUpstream sends a DNS query to upstream server func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) { conn, err := net.DialTimeout("udp", server, timeout) if err != nil { return nil, err } defer conn.Close() conn.SetDeadline(time.Now().Add(timeout)) if _, err := conn.Write(query); err != nil { return nil, err } response := make([]byte, 4096) n, err := conn.Read(response) if err != nil { return nil, err } return response[:n], nil } // runPacketSender sends packets from netstack back to TUN func (p *DNSProxy) runPacketSender() { defer p.wg.Done() // MessageTransportHeaderSize is the offset used by WireGuard device // for reading/writing packets to the TUN interface const offset = 16 for { select { case <-p.ctx.Done(): return default: } // Read packets from netstack endpoint pkt := p.ep.Read() if pkt == nil { // No packet available, small sleep to avoid busy loop time.Sleep(1 * time.Millisecond) continue } // Extract packet data as slices slices := pkt.AsSlices() if len(slices) > 0 { // Flatten all slices into a single packet buffer var totalSize int for _, slice := range slices { totalSize += len(slice) } // Allocate buffer with offset space for WireGuard transport header // The first 'offset' bytes are reserved for the transport header buf := make([]byte, offset+totalSize) // Copy packet data after the offset pos := offset for _, slice := range slices { copy(buf[pos:], slice) pos += len(slice) } // Write packet to TUN device // offset=16 indicates packet data starts at position 16 in the buffer _, err := p.tunDevice.Write([][]byte{buf}, offset) if err != nil { logger.Error("Failed to write DNS response to TUN: %v", err) } } pkt.DecRef() } }