Make it protocol aware

This commit is contained in:
Owen
2025-11-21 17:11:03 -05:00
parent 5505c1d2c7
commit 511f303559
7 changed files with 382 additions and 148 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,11 +36,12 @@ const (
// DNSProxy implements a DNS proxy using gvisor netstack
type DNSProxy struct {
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
recordStore *DNSRecordStore // Local DNS records
ctx context.Context
cancel context.CancelFunc
@@ -56,11 +58,12 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
ctx: ctx,
cancel: cancel,
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
recordStore: NewDNSRecordStore(),
ctx: ctx,
cancel: cancel,
}
// Create gvisor netstack
@@ -212,12 +215,112 @@ func (p *DNSProxy) runDNSListener() {
copy(query, buf[:n])
// Handle query in background
go p.forwardDNSQuery(udpConn, query, remoteAddr)
go p.handleDNSQuery(udpConn, query, remoteAddr)
}
}
// forwardDNSQuery forwards a DNS query to upstream DNS server
func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) {
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
// Parse the DNS query
msg := new(dns.Msg)
if err := msg.Unpack(queryData); err != nil {
logger.Error("Failed to parse DNS query: %v", err)
return
}
if len(msg.Question) == 0 {
logger.Debug("DNS query has no questions")
return
}
question := msg.Question[0]
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
// Check if we have local records for this query
var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
response = p.checkLocalRecords(msg, question)
}
// If no local records, forward to upstream
if response == nil {
logger.Debug("No local record for %s, forwarding upstream", question.Name)
response = p.forwardToUpstream(msg)
}
if response == nil {
logger.Error("Failed to get DNS response for %s", question.Name)
return
}
// Pack and send response
responseData, err := response.Pack()
if err != nil {
logger.Error("Failed to pack DNS response: %v", err)
return
}
_, err = udpConn.WriteTo(responseData, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
}
// checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
var recordType RecordType
if question.Qtype == dns.TypeA {
recordType = RecordTypeA
} else if question.Qtype == dns.TypeAAAA {
recordType = RecordTypeAAAA
} else {
return nil
}
ips := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 {
return nil
}
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add answer records
for _, ip := range ips {
var rr dns.RR
if question.Qtype == dns.TypeA {
rr = &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
A: ip.To4(),
}
} else { // TypeAAAA
rr = &dns.AAAA{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
AAAA: ip.To16(),
}
}
response.Answer = append(response.Answer, rr)
}
return response
}
// forwardToUpstream forwards a DNS query to upstream DNS servers
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
// Try primary DNS server
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
if err != nil {
@@ -226,38 +329,24 @@ func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientA
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
if err != nil {
logger.Error("Both DNS servers failed: %v", err)
return
return nil
}
}
// Send response back to client through netstack
_, err = udpConn.WriteTo(response, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
return response
}
// queryUpstream sends a DNS query to upstream server
func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) {
conn, err := net.DialTimeout("udp", server, timeout)
if err != nil {
return nil, err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(timeout))
if _, err := conn.Write(query); err != nil {
return nil, err
// queryUpstream sends a DNS query to upstream server using miekg/dns
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
client := &dns.Client{
Timeout: timeout,
}
response := make([]byte, 4096)
n, err := conn.Read(response)
response, _, err := client.Exchange(query, server)
if err != nil {
return nil, err
}
return response[:n], nil
return response, nil
}
// runPacketSender sends packets from netstack back to TUN
@@ -314,3 +403,26 @@ func (p *DNSProxy) runPacketSender() {
pkt.DecRef()
}
}
// AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
return p.recordStore.AddRecord(domain, ip)
}
// RemoveDNSRecord removes a DNS record from the local store
// If ip is nil, removes all records for the domain
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip)
}
// GetDNSRecords returns all IP addresses for a domain and record type
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
return p.recordStore.GetRecords(domain, recordType)
}
// ClearDNSRecords removes all DNS records from the local store
func (p *DNSProxy) ClearDNSRecords() {
p.recordStore.Clear()
}