Make it protocol aware

Former-commit-id: 511f303559
This commit is contained in:
Owen
2025-11-21 17:11:03 -05:00
parent d7cd746cc9
commit c230c7be28
7 changed files with 382 additions and 148 deletions

View File

@@ -11,6 +11,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,11 +36,12 @@ const (
// DNSProxy implements a DNS proxy using gvisor netstack
type DNSProxy struct {
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
recordStore *DNSRecordStore // Local DNS records
ctx context.Context
cancel context.CancelFunc
@@ -56,11 +58,12 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
ctx: ctx,
cancel: cancel,
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
recordStore: NewDNSRecordStore(),
ctx: ctx,
cancel: cancel,
}
// Create gvisor netstack
@@ -212,12 +215,112 @@ func (p *DNSProxy) runDNSListener() {
copy(query, buf[:n])
// Handle query in background
go p.forwardDNSQuery(udpConn, query, remoteAddr)
go p.handleDNSQuery(udpConn, query, remoteAddr)
}
}
// forwardDNSQuery forwards a DNS query to upstream DNS server
func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) {
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
// Parse the DNS query
msg := new(dns.Msg)
if err := msg.Unpack(queryData); err != nil {
logger.Error("Failed to parse DNS query: %v", err)
return
}
if len(msg.Question) == 0 {
logger.Debug("DNS query has no questions")
return
}
question := msg.Question[0]
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
// Check if we have local records for this query
var response *dns.Msg
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
response = p.checkLocalRecords(msg, question)
}
// If no local records, forward to upstream
if response == nil {
logger.Debug("No local record for %s, forwarding upstream", question.Name)
response = p.forwardToUpstream(msg)
}
if response == nil {
logger.Error("Failed to get DNS response for %s", question.Name)
return
}
// Pack and send response
responseData, err := response.Pack()
if err != nil {
logger.Error("Failed to pack DNS response: %v", err)
return
}
_, err = udpConn.WriteTo(responseData, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
}
// checkLocalRecords checks if we have local records for the query
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
var recordType RecordType
if question.Qtype == dns.TypeA {
recordType = RecordTypeA
} else if question.Qtype == dns.TypeAAAA {
recordType = RecordTypeAAAA
} else {
return nil
}
ips := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 {
return nil
}
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
// Add answer records
for _, ip := range ips {
var rr dns.RR
if question.Qtype == dns.TypeA {
rr = &dns.A{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
A: ip.To4(),
}
} else { // TypeAAAA
rr = &dns.AAAA{
Hdr: dns.RR_Header{
Name: question.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300, // 5 minutes
},
AAAA: ip.To16(),
}
}
response.Answer = append(response.Answer, rr)
}
return response
}
// forwardToUpstream forwards a DNS query to upstream DNS servers
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
// Try primary DNS server
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
if err != nil {
@@ -226,38 +329,24 @@ func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientA
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
if err != nil {
logger.Error("Both DNS servers failed: %v", err)
return
return nil
}
}
// Send response back to client through netstack
_, err = udpConn.WriteTo(response, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
return response
}
// queryUpstream sends a DNS query to upstream server
func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) {
conn, err := net.DialTimeout("udp", server, timeout)
if err != nil {
return nil, err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(timeout))
if _, err := conn.Write(query); err != nil {
return nil, err
// queryUpstream sends a DNS query to upstream server using miekg/dns
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
client := &dns.Client{
Timeout: timeout,
}
response := make([]byte, 4096)
n, err := conn.Read(response)
response, _, err := client.Exchange(query, server)
if err != nil {
return nil, err
}
return response[:n], nil
return response, nil
}
// runPacketSender sends packets from netstack back to TUN
@@ -314,3 +403,26 @@ func (p *DNSProxy) runPacketSender() {
pkt.DecRef()
}
}
// AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
return p.recordStore.AddRecord(domain, ip)
}
// RemoveDNSRecord removes a DNS record from the local store
// If ip is nil, removes all records for the domain
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip)
}
// GetDNSRecords returns all IP addresses for a domain and record type
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
return p.recordStore.GetRecords(domain, recordType)
}
// ClearDNSRecords removes all DNS records from the local store
func (p *DNSProxy) ClearDNSRecords() {
p.recordStore.Clear()
}

166
dns/dns_records.go Normal file
View File

@@ -0,0 +1,166 @@
package dns
import (
"net"
"sync"
"github.com/miekg/dns"
)
// RecordType represents the type of DNS record
type RecordType uint16
const (
RecordTypeA RecordType = RecordType(dns.TypeA)
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
)
// DNSRecordStore manages local DNS records for A and AAAA queries
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
}
// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
}
}
// AddRecord adds a DNS record mapping (A or AAAA)
// domain should be in FQDN format (e.g., "example.com.")
// ip should be a valid IPv4 or IPv6 address
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
if ip.To4() != nil {
// IPv4 address
s.aRecords[domain] = append(s.aRecords[domain], ip)
} else if ip.To16() != nil {
// IPv6 address
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()}
}
return nil
}
// RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "."
}
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
if ip == nil {
// Remove all records for this domain
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
return
}
if ip.To4() != nil {
// Remove specific IPv4 address
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
}
}
}
// GetRecords returns all IP addresses for a domain and record type
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
var records []net.IP
switch recordType {
case RecordTypeA:
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
}
case RecordTypeAAAA:
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
}
}
return records
}
// HasRecord checks if a domain has any records of the specified type
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock()
defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain)
switch recordType {
case RecordTypeA:
_, ok := s.aRecords[domain]
return ok
case RecordTypeAAAA:
_, ok := s.aaaaRecords[domain]
return ok
}
return false
}
// Clear removes all records from the store
func (s *DNSRecordStore) Clear() {
s.mu.Lock()
defer s.mu.Unlock()
s.aRecords = make(map[string][]net.IP)
s.aaaaRecords = make(map[string][]net.IP)
}
// removeIP is a helper function to remove a specific IP from a slice
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
result := make([]net.IP, 0, len(ips))
for _, ip := range ips {
if !ip.Equal(toRemove) {
result = append(result, ip)
}
}
return result
}

53
dns/example_usage.go Normal file
View File

@@ -0,0 +1,53 @@
package dns
// Example usage of DNS record management (not compiled, just for reference)
/*
import (
"net"
"github.com/fosrl/olm/dns"
)
func exampleUsage() {
// Assuming you have a DNSProxy instance
var proxy *dns.DNSProxy
// Add an A record for example.com pointing to 192.168.1.100
ip := net.ParseIP("192.168.1.100")
err := proxy.AddDNSRecord("example.com", ip)
if err != nil {
// Handle error
}
// Add multiple A records for the same domain (round-robin)
proxy.AddDNSRecord("example.com", net.ParseIP("192.168.1.101"))
proxy.AddDNSRecord("example.com", net.ParseIP("192.168.1.102"))
// Add an AAAA record (IPv6)
ipv6 := net.ParseIP("2001:db8::1")
proxy.AddDNSRecord("example.com", ipv6)
// Query records
aRecords := proxy.GetDNSRecords("example.com", dns.RecordTypeA)
// Returns: [192.168.1.100, 192.168.1.101, 192.168.1.102]
aaaaRecords := proxy.GetDNSRecords("example.com", dns.RecordTypeAAAA)
// Returns: [2001:db8::1]
// Remove a specific record
proxy.RemoveDNSRecord("example.com", net.ParseIP("192.168.1.101"))
// Remove all records for a domain
proxy.RemoveDNSRecord("example.com", nil)
// Clear all DNS records
proxy.ClearDNSRecords()
}
// How it works:
// 1. When a DNS query arrives, the proxy first checks its local record store
// 2. If a matching A or AAAA record exists locally, it returns that immediately
// 3. If no local record exists, it forwards the query to upstream DNS (8.8.8.8 or 8.8.4.4)
// 4. All other DNS record types (MX, CNAME, TXT, etc.) are always forwarded upstream
*/