mirror of
https://github.com/fosrl/olm.git
synced 2026-03-07 03:06:44 +00:00
Replace trie-based domain lookup with simple map for O(1) lookups. Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses.
760 lines
20 KiB
Go
760 lines
20 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/fosrl/newt/logger"
|
|
"github.com/fosrl/newt/util"
|
|
"github.com/fosrl/olm/device"
|
|
"github.com/miekg/dns"
|
|
"gvisor.dev/gvisor/pkg/buffer"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
)
|
|
|
|
// DNSPort is the UDP port on which the proxy accepts DNS queries.
const DNSPort = 53
|
|
|
|
// DNSProxy implements a DNS proxy using gvisor netstack
|
|
type DNSProxy struct {
|
|
stack *stack.Stack
|
|
ep *channel.Endpoint
|
|
proxyIP netip.Addr
|
|
upstreamDNS []string
|
|
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
|
mtu int
|
|
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
|
|
recordStore *DNSRecordStore // Local DNS records
|
|
|
|
// Tunnel DNS fields - for sending queries over WireGuard
|
|
tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries)
|
|
tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries
|
|
tunnelEp *channel.Endpoint
|
|
tunnelActivePorts map[uint16]bool
|
|
tunnelPortsLock sync.Mutex
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// NewDNSProxy creates a new DNS proxy
|
|
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
|
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
|
}
|
|
|
|
if len(upstreamDns) == 0 {
|
|
return nil, fmt.Errorf("at least one upstream DNS server must be specified")
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
proxy := &DNSProxy{
|
|
proxyIP: proxyIP,
|
|
mtu: mtu,
|
|
middleDevice: middleDevice,
|
|
upstreamDNS: upstreamDns,
|
|
tunnelDNS: tunnelDns,
|
|
recordStore: NewDNSRecordStore(),
|
|
tunnelActivePorts: make(map[uint16]bool),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
// Parse tunnel IP if provided (needed for tunneled DNS)
|
|
if tunnelIP != "" {
|
|
addr, err := netip.ParseAddr(tunnelIP)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse tunnel IP: %v", err)
|
|
}
|
|
proxy.tunnelIP = addr
|
|
}
|
|
|
|
// Create gvisor netstack for receiving DNS queries
|
|
stackOpts := stack.Options{
|
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
|
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
|
HandleLocal: true,
|
|
}
|
|
|
|
proxy.ep = channel.New(256, uint32(mtu), "")
|
|
proxy.stack = stack.New(stackOpts)
|
|
|
|
// Create NIC
|
|
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
|
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
|
}
|
|
|
|
// Add IP address
|
|
// Parse the proxy IP to get the octets
|
|
ipBytes := proxyIP.As4()
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
Protocol: ipv4.ProtocolNumber,
|
|
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
|
}
|
|
|
|
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
|
return nil, fmt.Errorf("failed to add protocol address: %v", err)
|
|
}
|
|
|
|
// Add default route
|
|
proxy.stack.AddRoute(tcpip.Route{
|
|
Destination: header.IPv4EmptySubnet,
|
|
NIC: 1,
|
|
})
|
|
|
|
// Initialize tunnel netstack if tunnel DNS is enabled
|
|
if tunnelDns {
|
|
if !proxy.tunnelIP.IsValid() {
|
|
return nil, fmt.Errorf("tunnel IP is required when tunnelDNS is enabled")
|
|
}
|
|
|
|
// TODO: DO WE NEED TO ESTABLISH ANOTHER NETSTACK HERE OR CAN WE COMBINE WITH WGTESTER?
|
|
if err := proxy.initTunnelNetstack(); err != nil {
|
|
return nil, fmt.Errorf("failed to initialize tunnel netstack: %v", err)
|
|
}
|
|
}
|
|
|
|
return proxy, nil
|
|
}
|
|
|
|
// initTunnelNetstack creates a separate netstack for outbound DNS queries through the tunnel
|
|
func (p *DNSProxy) initTunnelNetstack() error {
|
|
// Create gvisor netstack for outbound tunnel queries
|
|
stackOpts := stack.Options{
|
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
|
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
|
HandleLocal: true,
|
|
}
|
|
|
|
p.tunnelEp = channel.New(256, uint32(p.mtu), "")
|
|
p.tunnelStack = stack.New(stackOpts)
|
|
|
|
// Create NIC
|
|
if err := p.tunnelStack.CreateNIC(1, p.tunnelEp); err != nil {
|
|
return fmt.Errorf("failed to create tunnel NIC: %v", err)
|
|
}
|
|
|
|
// Add tunnel IP address (WireGuard interface IP)
|
|
ipBytes := p.tunnelIP.As4()
|
|
protoAddr := tcpip.ProtocolAddress{
|
|
Protocol: ipv4.ProtocolNumber,
|
|
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
|
}
|
|
|
|
if err := p.tunnelStack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
|
return fmt.Errorf("failed to add tunnel protocol address: %v", err)
|
|
}
|
|
|
|
// Add default route
|
|
p.tunnelStack.AddRoute(tcpip.Route{
|
|
Destination: header.IPv4EmptySubnet,
|
|
NIC: 1,
|
|
})
|
|
|
|
// Register filter rule on MiddleDevice to intercept responses
|
|
p.middleDevice.AddRule(p.tunnelIP, p.handleTunnelResponse)
|
|
|
|
return nil
|
|
}
|
|
|
|
// handleTunnelResponse handles packets coming back from the tunnel destined for the tunnel IP
|
|
func (p *DNSProxy) handleTunnelResponse(packet []byte) bool {
|
|
// Check if it's UDP
|
|
proto, ok := util.GetProtocol(packet)
|
|
if !ok || proto != 17 { // UDP
|
|
return false
|
|
}
|
|
|
|
// Check destination port - should be one of our active outbound ports
|
|
port, ok := util.GetDestPort(packet)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
// Check if we are expecting a response on this port
|
|
p.tunnelPortsLock.Lock()
|
|
active := p.tunnelActivePorts[uint16(port)]
|
|
p.tunnelPortsLock.Unlock()
|
|
|
|
if !active {
|
|
return false
|
|
}
|
|
|
|
// Inject into tunnel netstack
|
|
version := packet[0] >> 4
|
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
Payload: buffer.MakeWithData(packet),
|
|
})
|
|
|
|
switch version {
|
|
case 4:
|
|
p.tunnelEp.InjectInbound(ipv4.ProtocolNumber, pkb)
|
|
case 6:
|
|
p.tunnelEp.InjectInbound(ipv6.ProtocolNumber, pkb)
|
|
default:
|
|
pkb.DecRef()
|
|
return false
|
|
}
|
|
|
|
pkb.DecRef()
|
|
return true // Handled
|
|
}
|
|
|
|
// Start starts the DNS proxy and registers with the filter
|
|
func (p *DNSProxy) Start() error {
|
|
// Install packet filter rule
|
|
p.middleDevice.AddRule(p.proxyIP, p.handlePacket)
|
|
|
|
// Start DNS listener
|
|
p.wg.Add(2)
|
|
go p.runDNSListener()
|
|
go p.runPacketSender()
|
|
|
|
// Start tunnel packet sender if tunnel DNS is enabled
|
|
if p.tunnelDNS {
|
|
p.wg.Add(1)
|
|
go p.runTunnelPacketSender()
|
|
}
|
|
|
|
logger.Info("DNS proxy started on %s:%d (tunnelDNS=%v)", p.proxyIP.String(), DNSPort, p.tunnelDNS)
|
|
return nil
|
|
}
|
|
|
|
// Stop stops the DNS proxy
|
|
func (p *DNSProxy) Stop() {
|
|
if p.middleDevice != nil {
|
|
p.middleDevice.RemoveRule(p.proxyIP)
|
|
if p.tunnelDNS && p.tunnelIP.IsValid() {
|
|
p.middleDevice.RemoveRule(p.tunnelIP)
|
|
}
|
|
}
|
|
p.cancel()
|
|
|
|
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
|
|
if p.ep != nil {
|
|
p.ep.Close()
|
|
}
|
|
|
|
// Close tunnel endpoint if it exists
|
|
if p.tunnelEp != nil {
|
|
p.tunnelEp.Close()
|
|
}
|
|
|
|
p.wg.Wait()
|
|
|
|
if p.stack != nil {
|
|
p.stack.Close()
|
|
}
|
|
|
|
if p.tunnelStack != nil {
|
|
p.tunnelStack.Close()
|
|
}
|
|
|
|
logger.Info("DNS proxy stopped")
|
|
}
|
|
|
|
func (p *DNSProxy) GetProxyIP() netip.Addr {
|
|
return p.proxyIP
|
|
}
|
|
|
|
// handlePacket is called by the filter for packets destined to DNS proxy IP
|
|
func (p *DNSProxy) handlePacket(packet []byte) bool {
|
|
if len(packet) < 20 {
|
|
return false // Don't drop, malformed
|
|
}
|
|
|
|
// Quick check for UDP port 53
|
|
proto, ok := util.GetProtocol(packet)
|
|
if !ok || proto != 17 { // 17 = UDP
|
|
return false // Not UDP, don't handle
|
|
}
|
|
|
|
port, ok := util.GetDestPort(packet)
|
|
if !ok || port != DNSPort {
|
|
return false // Not DNS port
|
|
}
|
|
|
|
// Inject packet into our netstack
|
|
version := packet[0] >> 4
|
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
Payload: buffer.MakeWithData(packet),
|
|
})
|
|
|
|
switch version {
|
|
case 4:
|
|
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
|
case 6:
|
|
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
|
default:
|
|
pkb.DecRef()
|
|
return false
|
|
}
|
|
|
|
pkb.DecRef()
|
|
return true // Drop packet from normal path
|
|
}
|
|
|
|
// runDNSListener listens for DNS queries on the netstack
|
|
func (p *DNSProxy) runDNSListener() {
|
|
defer p.wg.Done()
|
|
|
|
// Create UDP listener using gonet
|
|
// Parse the proxy IP to get the octets
|
|
ipBytes := p.proxyIP.As4()
|
|
laddr := &tcpip.FullAddress{
|
|
NIC: 1,
|
|
Addr: tcpip.AddrFrom4(ipBytes),
|
|
Port: DNSPort,
|
|
}
|
|
|
|
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
|
|
if err != nil {
|
|
logger.Error("Failed to create DNS listener: %v", err)
|
|
return
|
|
}
|
|
defer udpConn.Close()
|
|
|
|
logger.Debug("DNS proxy listening on netstack")
|
|
|
|
// Handle DNS queries
|
|
buf := make([]byte, 4096)
|
|
for {
|
|
select {
|
|
case <-p.ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
|
n, remoteAddr, err := udpConn.ReadFrom(buf)
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
continue
|
|
}
|
|
if p.ctx.Err() != nil {
|
|
return
|
|
}
|
|
logger.Error("DNS read error: %v", err)
|
|
continue
|
|
}
|
|
|
|
query := make([]byte, n)
|
|
copy(query, buf[:n])
|
|
|
|
// Handle query in background
|
|
go p.handleDNSQuery(udpConn, query, remoteAddr)
|
|
}
|
|
}
|
|
|
|
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
|
|
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
|
|
// Parse the DNS query
|
|
msg := new(dns.Msg)
|
|
if err := msg.Unpack(queryData); err != nil {
|
|
logger.Error("Failed to parse DNS query: %v", err)
|
|
return
|
|
}
|
|
|
|
if len(msg.Question) == 0 {
|
|
logger.Debug("DNS query has no questions")
|
|
return
|
|
}
|
|
|
|
question := msg.Question[0]
|
|
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
|
|
|
|
// Check if we have local records for this query
|
|
var response *dns.Msg
|
|
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA || question.Qtype == dns.TypePTR {
|
|
response = p.checkLocalRecords(msg, question)
|
|
}
|
|
|
|
// If no local records, forward to upstream
|
|
if response == nil {
|
|
logger.Debug("No local record for %s, forwarding upstream to %v", question.Name, p.upstreamDNS)
|
|
response = p.forwardToUpstream(msg)
|
|
}
|
|
|
|
if response == nil {
|
|
logger.Error("Failed to get DNS response for %s", question.Name)
|
|
return
|
|
}
|
|
|
|
// Pack and send response
|
|
responseData, err := response.Pack()
|
|
if err != nil {
|
|
logger.Error("Failed to pack DNS response: %v", err)
|
|
return
|
|
}
|
|
|
|
_, err = udpConn.WriteTo(responseData, clientAddr)
|
|
if err != nil {
|
|
logger.Error("Failed to send DNS response: %v", err)
|
|
}
|
|
}
|
|
|
|
// checkLocalRecords checks if we have local records for the query
|
|
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
|
// Handle PTR queries
|
|
if question.Qtype == dns.TypePTR {
|
|
if ptrDomain, ok := p.recordStore.GetPTRRecord(question.Name); ok {
|
|
logger.Debug("Found local PTR record for %s -> %s", question.Name, ptrDomain)
|
|
|
|
// Create response message
|
|
response := new(dns.Msg)
|
|
response.SetReply(query)
|
|
response.Authoritative = true
|
|
|
|
// Add PTR answer record
|
|
rr := &dns.PTR{
|
|
Hdr: dns.RR_Header{
|
|
Name: question.Name,
|
|
Rrtype: dns.TypePTR,
|
|
Class: dns.ClassINET,
|
|
Ttl: 300, // 5 minutes
|
|
},
|
|
Ptr: ptrDomain,
|
|
}
|
|
response.Answer = append(response.Answer, rr)
|
|
|
|
return response
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Handle A and AAAA queries
|
|
var recordType RecordType
|
|
if question.Qtype == dns.TypeA {
|
|
recordType = RecordTypeA
|
|
} else if question.Qtype == dns.TypeAAAA {
|
|
recordType = RecordTypeAAAA
|
|
} else {
|
|
return nil
|
|
}
|
|
|
|
ips, exists := p.recordStore.GetRecords(question.Name, recordType)
|
|
if !exists {
|
|
// Domain not found in local records, forward to upstream
|
|
return nil
|
|
}
|
|
|
|
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
|
|
|
// Create response message (NODATA if no records, otherwise with answers)
|
|
response := new(dns.Msg)
|
|
response.SetReply(query)
|
|
response.Authoritative = true
|
|
|
|
// Add answer records (loop is a no-op if ips is empty)
|
|
for _, ip := range ips {
|
|
var rr dns.RR
|
|
if question.Qtype == dns.TypeA {
|
|
rr = &dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: question.Name,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 300, // 5 minutes
|
|
},
|
|
A: ip.To4(),
|
|
}
|
|
} else { // TypeAAAA
|
|
rr = &dns.AAAA{
|
|
Hdr: dns.RR_Header{
|
|
Name: question.Name,
|
|
Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET,
|
|
Ttl: 300, // 5 minutes
|
|
},
|
|
AAAA: ip.To16(),
|
|
}
|
|
}
|
|
response.Answer = append(response.Answer, rr)
|
|
}
|
|
|
|
return response
|
|
}
|
|
|
|
// forwardToUpstream forwards a DNS query to upstream DNS servers
|
|
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
|
|
// Try primary DNS server
|
|
response, err := p.queryUpstream(p.upstreamDNS[0], query, 2*time.Second)
|
|
if err != nil && len(p.upstreamDNS) > 1 {
|
|
// Try secondary DNS server
|
|
logger.Debug("Primary DNS failed, trying secondary: %v", err)
|
|
response, err = p.queryUpstream(p.upstreamDNS[1], query, 2*time.Second)
|
|
if err != nil {
|
|
logger.Error("Both DNS servers failed: %v", err)
|
|
return nil
|
|
}
|
|
}
|
|
return response
|
|
}
|
|
|
|
// queryUpstream sends a DNS query to upstream server
|
|
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
|
if p.tunnelDNS {
|
|
return p.queryUpstreamTunnel(server, query, timeout)
|
|
}
|
|
return p.queryUpstreamDirect(server, query, timeout)
|
|
}
|
|
|
|
// queryUpstreamDirect sends a DNS query to upstream server using miekg/dns directly (host networking)
|
|
func (p *DNSProxy) queryUpstreamDirect(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
|
client := &dns.Client{
|
|
Timeout: timeout,
|
|
}
|
|
|
|
response, _, err := client.Exchange(query, server)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// queryUpstreamTunnel sends a DNS query through the WireGuard tunnel
|
|
func (p *DNSProxy) queryUpstreamTunnel(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
|
// Dial through the tunnel netstack
|
|
conn, port, err := p.dialTunnel("udp", server)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial tunnel: %v", err)
|
|
}
|
|
defer func() {
|
|
conn.Close()
|
|
p.removeTunnelPort(port)
|
|
}()
|
|
|
|
// Pack the query
|
|
queryData, err := query.Pack()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to pack query: %v", err)
|
|
}
|
|
|
|
// Set deadline
|
|
conn.SetDeadline(time.Now().Add(timeout))
|
|
|
|
// Send the query
|
|
_, err = conn.Write(queryData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send query: %v", err)
|
|
}
|
|
|
|
// Read the response
|
|
buf := make([]byte, 4096)
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read response: %v", err)
|
|
}
|
|
|
|
// Parse the response
|
|
response := new(dns.Msg)
|
|
if err := response.Unpack(buf[:n]); err != nil {
|
|
return nil, fmt.Errorf("failed to unpack response: %v", err)
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
// dialTunnel creates a UDP connection through the tunnel netstack
|
|
func (p *DNSProxy) dialTunnel(network, addr string) (net.Conn, uint16, error) {
|
|
if p.tunnelStack == nil {
|
|
return nil, 0, fmt.Errorf("tunnel netstack not initialized")
|
|
}
|
|
|
|
// Parse remote address
|
|
raddr, err := net.ResolveUDPAddr("udp", addr)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Use tunnel IP as source
|
|
ipBytes := p.tunnelIP.As4()
|
|
|
|
// Create UDP connection with ephemeral port
|
|
laddr := &tcpip.FullAddress{
|
|
NIC: 1,
|
|
Addr: tcpip.AddrFrom4(ipBytes),
|
|
Port: 0,
|
|
}
|
|
|
|
raddrTcpip := &tcpip.FullAddress{
|
|
NIC: 1,
|
|
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
|
Port: uint16(raddr.Port),
|
|
}
|
|
|
|
conn, err := gonet.DialUDP(p.tunnelStack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Get local port
|
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
|
port := uint16(localAddr.Port)
|
|
|
|
// Register port so we can receive responses
|
|
p.tunnelPortsLock.Lock()
|
|
p.tunnelActivePorts[port] = true
|
|
p.tunnelPortsLock.Unlock()
|
|
|
|
return conn, port, nil
|
|
}
|
|
|
|
// removeTunnelPort removes a port from the active ports map
|
|
func (p *DNSProxy) removeTunnelPort(port uint16) {
|
|
p.tunnelPortsLock.Lock()
|
|
delete(p.tunnelActivePorts, port)
|
|
p.tunnelPortsLock.Unlock()
|
|
}
|
|
|
|
// runTunnelPacketSender reads packets from tunnel netstack and injects them into WireGuard
|
|
func (p *DNSProxy) runTunnelPacketSender() {
|
|
defer p.wg.Done()
|
|
logger.Debug("DNS tunnel packet sender goroutine started")
|
|
|
|
for {
|
|
// Use blocking ReadContext instead of polling - much more CPU efficient
|
|
// This will block until a packet is available or context is cancelled
|
|
pkt := p.tunnelEp.ReadContext(p.ctx)
|
|
if pkt == nil {
|
|
// Context was cancelled or endpoint closed
|
|
logger.Debug("DNS tunnel packet sender exiting")
|
|
// Drain any remaining packets
|
|
for {
|
|
pkt := p.tunnelEp.Read()
|
|
if pkt == nil {
|
|
break
|
|
}
|
|
pkt.DecRef()
|
|
}
|
|
return
|
|
}
|
|
|
|
// Extract packet data
|
|
slices := pkt.AsSlices()
|
|
if len(slices) > 0 {
|
|
var totalSize int
|
|
for _, slice := range slices {
|
|
totalSize += len(slice)
|
|
}
|
|
|
|
buf := make([]byte, totalSize)
|
|
pos := 0
|
|
for _, slice := range slices {
|
|
copy(buf[pos:], slice)
|
|
pos += len(slice)
|
|
}
|
|
|
|
// Inject into MiddleDevice (outbound to WG)
|
|
p.middleDevice.InjectOutbound(buf)
|
|
}
|
|
|
|
pkt.DecRef()
|
|
}
|
|
}
|
|
|
|
// runPacketSender sends packets from netstack back to TUN
|
|
func (p *DNSProxy) runPacketSender() {
|
|
defer p.wg.Done()
|
|
|
|
// MessageTransportHeaderSize is the offset used by WireGuard device
|
|
// for reading/writing packets to the TUN interface
|
|
const offset = 16
|
|
|
|
for {
|
|
// Use blocking ReadContext instead of polling - much more CPU efficient
|
|
// This will block until a packet is available or context is cancelled
|
|
pkt := p.ep.ReadContext(p.ctx)
|
|
if pkt == nil {
|
|
// Context was cancelled or endpoint closed
|
|
return
|
|
}
|
|
|
|
// Extract packet data as slices
|
|
slices := pkt.AsSlices()
|
|
if len(slices) > 0 {
|
|
// Flatten all slices into a single packet buffer
|
|
var totalSize int
|
|
for _, slice := range slices {
|
|
totalSize += len(slice)
|
|
}
|
|
|
|
// Allocate buffer with offset space for WireGuard transport header
|
|
// The first 'offset' bytes are reserved for the transport header
|
|
buf := make([]byte, offset+totalSize)
|
|
|
|
// Copy packet data after the offset
|
|
pos := offset
|
|
for _, slice := range slices {
|
|
copy(buf[pos:], slice)
|
|
pos += len(slice)
|
|
}
|
|
|
|
// Write packet to TUN device via MiddleDevice
|
|
// offset=16 indicates packet data starts at position 16 in the buffer
|
|
_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
|
|
if err != nil {
|
|
logger.Error("Failed to write DNS response to TUN: %v", err)
|
|
}
|
|
}
|
|
|
|
pkt.DecRef()
|
|
}
|
|
}
|
|
|
|
// AddDNSRecord adds a DNS record to the local store
|
|
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
|
// ip should be a valid IPv4 or IPv6 address
|
|
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
|
|
return p.recordStore.AddRecord(domain, ip)
|
|
}
|
|
|
|
// RemoveDNSRecord removes a DNS record from the local store
|
|
// If ip is nil, removes all records for the domain
|
|
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
|
p.recordStore.RemoveRecord(domain, ip)
|
|
}
|
|
|
|
// GetDNSRecords returns all IP addresses for a domain and record type.
|
|
// The second return value indicates whether the domain exists.
|
|
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
|
return p.recordStore.GetRecords(domain, recordType)
|
|
}
|
|
|
|
// ClearDNSRecords removes all DNS records from the local store
|
|
func (p *DNSProxy) ClearDNSRecords() {
|
|
p.recordStore.Clear()
|
|
}
|
|
|
|
// PickIPFromSubnet parses subnet (CIDR notation) and returns the address that
// immediately follows the prefix's own address, e.g. 10.0.0.1 for
// "10.0.0.0/24".
//
// An error is returned when subnet does not parse, or when that next address
// does not exist or lies outside the subnet (for example a /32, or a prefix
// written with the last address of its range).
func PickIPFromSubnet(subnet string) (netip.Addr, error) {
	prefix, err := netip.ParsePrefix(subnet)
	if err != nil {
		return netip.Addr{}, fmt.Errorf("invalid subnet: %w", err)
	}

	// Next returns the zero Addr on overflow; Contains catches the case where
	// stepping forward leaves the subnet.
	ip := prefix.Addr().Next()
	if !ip.IsValid() || !prefix.Contains(ip) {
		return netip.Addr{}, fmt.Errorf("no valid IP address found in subnet: %s", subnet)
	}

	return ip, nil
}
|