diff --git a/config.go b/config.go index 4b1c824..2e13d6a 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,7 @@ type OlmConfig struct { DisableHolepunch bool `json:"disableHolepunch"` TlsClientCert string `json:"tlsClientCert"` OverrideDNS bool `json:"overrideDNS"` + TunnelDNS bool `json:"tunnelDNS"` DisableRelay bool `json:"disableRelay"` // DoNotCreateNewClient bool `json:"doNotCreateNewClient"` @@ -88,6 +89,7 @@ func DefaultConfig() *OlmConfig { PingInterval: "3s", PingTimeout: "5s", DisableHolepunch: false, + TunnelDNS: false, // DoNotCreateNewClient: false, sources: make(map[string]string), } @@ -105,6 +107,7 @@ func DefaultConfig() *OlmConfig { config.sources["pingTimeout"] = string(SourceDefault) config.sources["disableHolepunch"] = string(SourceDefault) config.sources["overrideDNS"] = string(SourceDefault) + config.sources["tunnelDNS"] = string(SourceDefault) config.sources["disableRelay"] = string(SourceDefault) // config.sources["doNotCreateNewClient"] = string(SourceDefault) @@ -265,6 +268,10 @@ func loadConfigFromEnv(config *OlmConfig) { config.DisableRelay = true config.sources["disableRelay"] = string(SourceEnv) } + if val := os.Getenv("TUNNEL_DNS"); val == "true" { + config.TunnelDNS = true + config.sources["tunnelDNS"] = string(SourceEnv) + } // if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" { // config.DoNotCreateNewClient = true // config.sources["doNotCreateNewClient"] = string(SourceEnv) @@ -295,6 +302,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { "disableHolepunch": config.DisableHolepunch, "overrideDNS": config.OverrideDNS, "disableRelay": config.DisableRelay, + "tunnelDNS": config.TunnelDNS, // "doNotCreateNewClient": config.DoNotCreateNewClient, } @@ -318,6 +326,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.BoolVar(&config.DisableHolepunch, "disable-holepunch", config.DisableHolepunch, "Disable hole punching") serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings") serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections") + serviceFlags.BoolVar(&config.TunnelDNS, "tunnel-dns", config.TunnelDNS, "Use tunnel for DNS traffic") // serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client") version := serviceFlags.Bool("version", false, "Print the version") @@ -393,6 +402,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.DisableRelay != origValues["disableRelay"].(bool) { config.sources["disableRelay"] = string(SourceCLI) } + if config.TunnelDNS != origValues["tunnelDNS"].(bool) { + config.sources["tunnelDNS"] = string(SourceCLI) + } // if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) { // config.sources["doNotCreateNewClient"] = string(SourceCLI) // } @@ -606,6 +618,7 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nAdvanced:") fmt.Printf(" disable-holepunch = %v [%s]\n", c.DisableHolepunch, getSource("disableHolepunch")) fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS")) + fmt.Printf(" tunnel-dns = %v [%s]\n", c.TunnelDNS, getSource("tunnelDNS")) fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay")) // fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient")) if c.TlsClientCert != "" { diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index d0ed7b3..6d56379 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -34,18 +34,26 @@ type DNSProxy struct { ep *channel.Endpoint proxyIP netip.Addr upstreamDNS []string + tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally mtu int tunDevice tun.Device // Direct reference to underlying TUN device for responses middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering recordStore *DNSRecordStore // Local DNS records + // Tunnel DNS fields - for sending queries over WireGuard + tunnelIP netip.Addr // WireGuard interface IP (source for tunneled queries) + tunnelStack *stack.Stack // Separate netstack for outbound tunnel queries + tunnelEp *channel.Endpoint + tunnelActivePorts map[uint16]bool + tunnelPortsLock sync.Mutex + ctx context.Context cancel context.CancelFunc wg sync.WaitGroup } // NewDNSProxy creates a new DNS proxy -func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string) (*DNSProxy, error) { +func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) { proxyIP, err := PickIPFromSubnet(utilitySubnet) if err != nil { return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err) @@ -58,17 +66,28 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in ctx, cancel := context.WithCancel(context.Background()) proxy := &DNSProxy{ - proxyIP: proxyIP, - mtu: mtu, - tunDevice: tunDevice, - middleDevice: middleDevice, - upstreamDNS: upstreamDns, - recordStore: NewDNSRecordStore(), - ctx: ctx, - cancel: cancel, + proxyIP: proxyIP, + mtu: mtu, + tunDevice: tunDevice, + middleDevice: middleDevice, + upstreamDNS: upstreamDns, + tunnelDNS: tunnelDns, + recordStore: NewDNSRecordStore(), + tunnelActivePorts: make(map[uint16]bool), + ctx: ctx, + cancel: cancel, } - // Create gvisor netstack + // Parse tunnel IP if provided (needed for tunneled DNS) + if tunnelIP != "" { + addr, err := netip.ParseAddr(tunnelIP) + if err != nil { + return nil, fmt.Errorf("failed to parse tunnel IP: %v", err) + } + proxy.tunnelIP = addr + } + + // Create gvisor netstack for receiving DNS queries stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, @@ -101,9 +120,104 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in NIC: 1, }) + // Initialize tunnel netstack if tunnel DNS is enabled + if tunnelDns { + if !proxy.tunnelIP.IsValid() { + return nil, fmt.Errorf("tunnel IP is required when tunnelDNS is enabled") + } + + // TODO: DO WE NEED TO ESTABLISH ANOTHER NETSTACK HERE OR CAN WE COMBINE WITH WGTESTER? + if err := proxy.initTunnelNetstack(); err != nil { + return nil, fmt.Errorf("failed to initialize tunnel netstack: %v", err) + } + } + return proxy, nil } +// initTunnelNetstack creates a separate netstack for outbound DNS queries through the tunnel +func (p *DNSProxy) initTunnelNetstack() error { + // Create gvisor netstack for outbound tunnel queries + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + p.tunnelEp = channel.New(256, uint32(p.mtu), "") + p.tunnelStack = stack.New(stackOpts) + + // Create NIC + if err := p.tunnelStack.CreateNIC(1, p.tunnelEp); err != nil { + return fmt.Errorf("failed to create tunnel NIC: %v", err) + } + + // Add tunnel IP address (WireGuard interface IP) + ipBytes := p.tunnelIP.As4() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(), + } + + if err := p.tunnelStack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + return fmt.Errorf("failed to add tunnel protocol address: %v", err) + } + + // Add default route + p.tunnelStack.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }) + + // Register filter rule on MiddleDevice to intercept responses + p.middleDevice.AddRule(p.tunnelIP, p.handleTunnelResponse) + + return nil +} + +// handleTunnelResponse handles packets coming back from the tunnel destined for the tunnel IP +func (p *DNSProxy) handleTunnelResponse(packet []byte) bool { + // Check if it's UDP + proto, ok := util.GetProtocol(packet) + if !ok || proto != 17 { // UDP + return false + } + + // Check destination port - should be one of our active outbound ports + port, ok := util.GetDestPort(packet) + if !ok { + return false + } + + // Check if we are expecting a response on this port + p.tunnelPortsLock.Lock() + active := p.tunnelActivePorts[uint16(port)] + p.tunnelPortsLock.Unlock() + + if !active { + return false + } + + // Inject into tunnel netstack + version := packet[0] >> 4 + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + + switch version { + case 4: + p.tunnelEp.InjectInbound(ipv4.ProtocolNumber, pkb) + case 6: + p.tunnelEp.InjectInbound(ipv6.ProtocolNumber, pkb) + default: + pkb.DecRef() + return false + } + + pkb.DecRef() + return true // Handled +} + // Start starts the DNS proxy and registers with the filter func (p *DNSProxy) Start() error { // Install packet filter rule @@ -114,7 +228,13 @@ func (p *DNSProxy) Start() error { go p.runDNSListener() go p.runPacketSender() - logger.Info("DNS proxy started on %s:%d", p.proxyIP.String(), DNSPort) + // Start tunnel packet sender if tunnel DNS is enabled + if p.tunnelDNS { + p.wg.Add(1) + go p.runTunnelPacketSender() + } + + logger.Info("DNS proxy started on %s:%d (tunnelDNS=%v)", p.proxyIP.String(), DNSPort, p.tunnelDNS) return nil } @@ -122,6 +242,9 @@ func (p *DNSProxy) Start() error { func (p *DNSProxy) Stop() { if p.middleDevice != nil { p.middleDevice.RemoveRule(p.proxyIP) + if p.tunnelDNS && p.tunnelIP.IsValid() { + p.middleDevice.RemoveRule(p.tunnelIP) + } } p.cancel() @@ -130,12 +253,21 @@ func (p *DNSProxy) Stop() { p.ep.Close() } + // Close tunnel endpoint if it exists + if p.tunnelEp != nil { + p.tunnelEp.Close() + } + p.wg.Wait() if p.stack != nil { p.stack.Close() } + if p.tunnelStack != nil { + p.tunnelStack.Close() + } + logger.Info("DNS proxy stopped") } @@ -348,8 +480,16 @@ func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg { return response } -// queryUpstream sends a DNS query to upstream server using miekg/dns +// queryUpstream sends a DNS query to upstream server func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) { + if p.tunnelDNS { + return p.queryUpstreamTunnel(server, query, timeout) + } + return p.queryUpstreamDirect(server, query, timeout) +} + +// queryUpstreamDirect sends a DNS query to upstream server using miekg/dns directly (host networking) +func (p *DNSProxy) queryUpstreamDirect(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) { client := &dns.Client{ Timeout: timeout, } @@ -362,6 +502,155 @@ func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Dur return response, nil } +// queryUpstreamTunnel sends a DNS query through the WireGuard tunnel +func (p *DNSProxy) queryUpstreamTunnel(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) { + // Dial through the tunnel netstack + conn, port, err := p.dialTunnel("udp", server) + if err != nil { + return nil, fmt.Errorf("failed to dial tunnel: %v", err) + } + defer func() { + conn.Close() + p.removeTunnelPort(port) + }() + + // Pack the query + queryData, err := query.Pack() + if err != nil { + return nil, fmt.Errorf("failed to pack query: %v", err) + } + + // Set deadline + conn.SetDeadline(time.Now().Add(timeout)) + + // Send the query + _, err = conn.Write(queryData) + if err != nil { + return nil, fmt.Errorf("failed to send query: %v", err) + } + + // Read the response + buf := make([]byte, 4096) + n, err := conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("failed to read response: %v", err) + } + + // Parse the response + response := new(dns.Msg) + if err := response.Unpack(buf[:n]); err != nil { + return nil, fmt.Errorf("failed to unpack response: %v", err) + } + + return response, nil +} + +// dialTunnel creates a UDP connection through the tunnel netstack +func (p *DNSProxy) dialTunnel(network, addr string) (net.Conn, uint16, error) { + if p.tunnelStack == nil { + return nil, 0, fmt.Errorf("tunnel netstack not initialized") + } + + // Parse remote address + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, 0, err + } + + // Use tunnel IP as source + ipBytes := p.tunnelIP.As4() + + // Create UDP connection with ephemeral port + laddr := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4(ipBytes), + Port: 0, + } + + raddrTcpip := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())), + Port: uint16(raddr.Port), + } + + conn, err := gonet.DialUDP(p.tunnelStack, laddr, raddrTcpip, ipv4.ProtocolNumber) + if err != nil { + return nil, 0, err + } + + // Get local port + localAddr := conn.LocalAddr().(*net.UDPAddr) + port := uint16(localAddr.Port) + + // Register port so we can receive responses + p.tunnelPortsLock.Lock() + p.tunnelActivePorts[port] = true + p.tunnelPortsLock.Unlock() + + return conn, port, nil +} + +// removeTunnelPort removes a port from the active ports map +func (p *DNSProxy) removeTunnelPort(port uint16) { + p.tunnelPortsLock.Lock() + delete(p.tunnelActivePorts, port) + p.tunnelPortsLock.Unlock() +} + +// runTunnelPacketSender reads packets from tunnel netstack and injects them into WireGuard +func (p *DNSProxy) runTunnelPacketSender() { + defer p.wg.Done() + logger.Debug("DNS tunnel packet sender goroutine started") + + ticker := time.NewTicker(1 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-p.ctx.Done(): + logger.Debug("DNS tunnel packet sender exiting") + // Drain any remaining packets + for { + pkt := p.tunnelEp.Read() + if pkt == nil { + break + } + pkt.DecRef() + } + return + case <-ticker.C: + // Try to read packets + for i := 0; i < 10; i++ { + pkt := p.tunnelEp.Read() + if pkt == nil { + break + } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + p.middleDevice.InjectOutbound(buf) + } + + pkt.DecRef() + } + } + } +} + // runPacketSender sends packets from netstack back to TUN func (p *DNSProxy) runPacketSender() { defer p.wg.Done() diff --git a/olm/olm.go b/olm/olm.go index a85b4c0..f84ee4f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -374,8 +374,14 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to bring up WireGuard device: %v", err) } + // Extract interface IP (strip CIDR notation if present) + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + // Create and start DNS proxy - dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS) + dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP) if err != nil { logger.Error("Failed to create DNS proxy: %v", err) } @@ -388,12 +394,6 @@ func StartTunnel(config TunnelConfig) { logger.Error("Failed to add route for utility subnet: %v", err) } - // TODO: seperate adding the callback to this so we can init it above with the interface - interfaceIP := wgData.TunnelIP - if strings.Contains(interfaceIP, "/") { - interfaceIP = strings.Split(interfaceIP, "/")[0] - } - // Create peer manager with integrated peer monitoring peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ Device: dev, diff --git a/olm/types.go b/olm/types.go index 993bb56..b7153af 100644 --- a/olm/types.go +++ b/olm/types.go @@ -61,6 +61,7 @@ type TunnelConfig struct { EnableUAPI bool OverrideDNS bool + TunnelDNS bool DisableRelay bool }