Make it protocol aware

Former-commit-id: 511f303559
This commit is contained in:
Owen
2025-11-21 17:11:03 -05:00
parent d7cd746cc9
commit c230c7be28
7 changed files with 382 additions and 148 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -40,6 +41,7 @@ type DNSProxy struct {
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
recordStore *DNSRecordStore // Local DNS records
ctx context.Context
cancel context.CancelFunc
@@ -59,6 +61,7 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
recordStore: NewDNSRecordStore(),
ctx: ctx,
cancel: cancel,
}
@@ -212,12 +215,112 @@ func (p *DNSProxy) runDNSListener() {
copy(query, buf[:n])
// Handle query in background
go p.forwardDNSQuery(udpConn, query, remoteAddr)
go p.handleDNSQuery(udpConn, query, remoteAddr)
}
}
// forwardDNSQuery forwards a DNS query to upstream DNS server
func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) {
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
// Parse the DNS query
msg := new(dns.Msg)
if err := msg.Unpack(queryData); err != nil {
logger.Error("Failed to parse DNS query: %v", err)
return
}
if len(msg.Question) == 0 {
logger.Debug("DNS query has no questions")
return
}
question := msg.Question[0]
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
// Check if we have local records for this query
var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
response = p.checkLocalRecords(msg, question)
}
// If no local records, forward to upstream
if response == nil {
logger.Debug("No local record for %s, forwarding upstream", question.Name)
response = p.forwardToUpstream(msg)
}
if response == nil {
logger.Error("Failed to get DNS response for %s", question.Name)
return
}
// Pack and send response
responseData, err := response.Pack()
if err != nil {
logger.Error("Failed to pack DNS response: %v", err)
return
}
_, err = udpConn.WriteTo(responseData, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
}
// checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
var recordType RecordType
if question.Qtype == dns.TypeA {
recordType = RecordTypeA
} else if question.Qtype == dns.TypeAAAA {
recordType = RecordTypeAAAA
} else {
return nil
}
ips := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 {
return nil
}
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add answer records
for _, ip := range ips {
var rr dns.RR
if question.Qtype == dns.TypeA {
rr = &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
A: ip.To4(),
}
} else { // TypeAAAA
rr = &dns.AAAA{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
AAAA: ip.To16(),
}
}
response.Answer = append(response.Answer, rr)
}
return response
}
// forwardToUpstream forwards a DNS query to upstream DNS servers
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
// Try primary DNS server
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
if err != nil {
@@ -226,38 +329,24 @@ func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientA
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
if err != nil {
logger.Error("Both DNS servers failed: %v", err)
return
return nil
}
}
return response
}
// Send response back to client through netstack
_, err = udpConn.WriteTo(response, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
// queryUpstream sends a DNS query to upstream server using miekg/dns
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
client := &dns.Client{
Timeout: timeout,
}
// queryUpstream sends a DNS query to upstream server
func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) {
conn, err := net.DialTimeout("udp", server, timeout)
if err != nil {
return nil, err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(timeout))
if _, err := conn.Write(query); err != nil {
return nil, err
}
response := make([]byte, 4096)
n, err := conn.Read(response)
response, _, err := client.Exchange(query, server)
if err != nil {
return nil, err
}
return response[:n], nil
return response, nil
}
// runPacketSender sends packets from netstack back to TUN
@@ -314,3 +403,26 @@ func (p *DNSProxy) runPacketSender() {
pkt.DecRef()
}
}
// AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
return p.recordStore.AddRecord(domain, ip)
}
// RemoveDNSRecord removes a DNS record from the local store
// If ip is nil, removes all records for the domain
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip)
}
// GetDNSRecords returns all IP addresses for a domain and record type
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
return p.recordStore.GetRecords(domain, recordType)
}
// ClearDNSRecords removes all DNS records from the local store
func (p *DNSProxy) ClearDNSRecords() {
p.recordStore.Clear()
}

166
dns/dns_records.go Normal file
View File

@@ -0,0 +1,166 @@
package dns
import (
"net"
"sync"
"github.com/miekg/dns"
)
// RecordType represents the type of DNS record
type RecordType uint16
const (
RecordTypeA RecordType = RecordType(dns.TypeA)
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
)
// DNSRecordStore manages local DNS records for A and AAAA queries
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
}
// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
}
}
// AddRecord adds a DNS record mapping (A or AAAA)
// domain should be in FQDN format (e.g., "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
if ip.To4() != nil {
// IPv4 address
s.aRecords[domain] = append(s.aRecords[domain], ip)
} else if ip.To16() != nil {
// IPv6 address
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()}
}
return nil
}
// RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
if ip == nil {
// Remove all records for this domain
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
return
}
if ip.To4() != nil {
// Remove specific IPv4 address
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
}
}
}
// GetRecords returns all IP addresses for a domain and record type
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
var records []net.IP
switch recordType {
case RecordTypeA:
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
}
case RecordTypeAAAA:
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
}
}
return records
}
// HasRecord checks if a domain has any records of the specified type
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
switch recordType {
case RecordTypeA:
_, ok := s.aRecords[domain]
return ok
case RecordTypeAAAA:
_, ok := s.aaaaRecords[domain]
return ok
}
return false
}
// Clear removes all records from the store
func (s *DNSRecordStore) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.aRecords = make(map[string][]net.IP)
s.aaaaRecords = make(map[string][]net.IP)
}
// removeIP is a helper function to remove a specific IP from a slice
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
result := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if !ip.Equal(toRemove) {
result = append(result, ip)
}
}
return result
}

53
dns/example_usage.go Normal file
View File

@@ -0,0 +1,53 @@
package dns
// Example usage of DNS record management (not compiled, just for reference)
/*
import (
"net"
"github.com/fosrl/olm/dns"
)
func exampleUsage() {
// Assuming you have a DNSProxy instance
var proxy *dns.DNSProxy
// Add an A record for example.com pointing to 192.168.1.100
ip := net.ParseIP("192.168.1.100")
err := proxy.AddDNSRecord("example.com", ip)
if err != nil {
// Handle error
}
// Add multiple A records for the same domain (round-robin)
proxy.AddDNSRecord("example.com", net.ParseIP("192.168.1.101"))
proxy.AddDNSRecord("example.com", net.ParseIP("192.168.1.102"))
// Add an AAAA record (IPv6)
ipv6 := net.ParseIP("2001:db8::1")
proxy.AddDNSRecord("example.com", ipv6)
// Query records
aRecords := proxy.GetDNSRecords("example.com", dns.RecordTypeA)
// Returns: [192.168.1.100, 192.168.1.101, 192.168.1.102]
aaaaRecords := proxy.GetDNSRecords("example.com", dns.RecordTypeAAAA)
// Returns: [2001:db8::1]
// Remove a specific record
proxy.RemoveDNSRecord("example.com", net.ParseIP("192.168.1.101"))
// Remove all records for a domain
proxy.RemoveDNSRecord("example.com", nil)
// Clear all DNS records
proxy.ClearDNSRecords()
}
// How it works:
// 1. When a DNS query arrives, the proxy first checks its local record store
// 2. If a matching A or AAAA record exists locally, it returns that immediately
// 3. If no local record exists, it forwards the query to upstream DNS (8.8.8.8 or 8.8.4.4)
// 4. All other DNS record types (MX, CNAME, TXT, etc.) are always forwarded upstream
*/

4
go.mod
View File

@@ -16,11 +16,15 @@ require (
require (
github.com/google/btree v1.1.3 // indirect
github.com/miekg/dns v1.1.68 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.44.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
)

8
go.sum
View File

@@ -6,6 +6,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0=
github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
@@ -14,14 +16,20 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A=

View File

@@ -1,111 +0,0 @@
package olm
// This file demonstrates how to add additional virtual services using the FilteredDevice infrastructure
// Copy and modify this template to add new services
import (
"context"
"net/netip"
"sync"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
)
// Example: Simple echo server on 10.30.30.50:7777
const (
EchoProxyIP = "10.30.30.50"
EchoProxyPort = 7777
)
// EchoProxy implements a simple echo server
type EchoProxy struct {
proxyIP netip.Addr
tunDevice tun.Device
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewEchoProxy creates a new echo proxy instance
func NewEchoProxy(tunDevice tun.Device) (*EchoProxy, error) {
proxyIP := netip.MustParseAddr(EchoProxyIP)
ctx, cancel := context.WithCancel(context.Background())
return &EchoProxy{
proxyIP: proxyIP,
tunDevice: tunDevice,
ctx: ctx,
cancel: cancel,
}, nil
}
// Start registers the proxy with the filter
func (e *EchoProxy) Start(filter *FilteredDevice) error {
filter.AddRule(e.proxyIP, e.handlePacket)
logger.Info("Echo proxy started on %s:%d", EchoProxyIP, EchoProxyPort)
return nil
}
// Stop unregisters the proxy
func (e *EchoProxy) Stop(filter *FilteredDevice) {
if filter != nil {
filter.RemoveRule(e.proxyIP)
}
e.cancel()
e.wg.Wait()
logger.Info("Echo proxy stopped")
}
// handlePacket processes packets destined for the echo server
func (e *EchoProxy) handlePacket(packet []byte) bool {
// Quick validation
if len(packet) < 20 {
return false
}
// Check protocol (UDP)
proto, ok := GetProtocol(packet)
if !ok || proto != 17 {
return false
}
// Check port
port, ok := GetDestPort(packet)
if !ok || port != EchoProxyPort {
return false
}
// For a real implementation, you would:
// 1. Parse the UDP packet
// 2. Extract the payload
// 3. Create a response packet with swapped src/dest
// 4. Write response back to TUN device
logger.Debug("Echo proxy received packet (would echo back)")
// Return true to drop packet from normal WireGuard path
return true
}
// Example integration in olm.go:
//
// var echoProxy *EchoProxy
//
// // During tunnel setup (after creating filteredDev):
// echoProxy, err = NewEchoProxy(tdev)
// if err != nil {
// logger.Error("Failed to create echo proxy: %v", err)
// return
// }
// if err := echoProxy.Start(filteredDev); err != nil {
// logger.Error("Failed to start echo proxy: %v", err)
// return
// }
//
// // During tunnel teardown:
// if echoProxy != nil {
// echoProxy.Stop(filteredDev)
// echoProxy = nil
// }

View File

@@ -435,11 +435,13 @@ func StartTunnel(config TunnelConfig) {
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
return
}
if err := dnsProxy.Start(middleDev); err != nil {
logger.Error("Failed to start DNS proxy: %v", err)
return
}
ip := net.ParseIP("192.168.1.100")
if dnsProxy.AddDNSRecord("example.com", ip); err != nil {
logger.Error("Failed to add DNS record: %v", err)
}
// fileUAPI, err := func() (*os.File, error) {