mirror of
https://github.com/fosrl/olm.git
synced 2026-02-25 06:16:47 +00:00
Make it protocol aware
This commit is contained in:
182
dns/dns_proxy.go
182
dns/dns_proxy.go
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -35,11 +36,12 @@ const (
|
||||
|
||||
// DNSProxy implements a DNS proxy using gvisor netstack
|
||||
type DNSProxy struct {
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
recordStore *DNSRecordStore // Local DNS records
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -56,11 +58,12 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
recordStore: NewDNSRecordStore(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
@@ -212,12 +215,112 @@ func (p *DNSProxy) runDNSListener() {
|
||||
copy(query, buf[:n])
|
||||
|
||||
// Handle query in background
|
||||
go p.forwardDNSQuery(udpConn, query, remoteAddr)
|
||||
go p.handleDNSQuery(udpConn, query, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// forwardDNSQuery forwards a DNS query to upstream DNS server
|
||||
func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) {
|
||||
// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream
|
||||
func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) {
|
||||
// Parse the DNS query
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(queryData); err != nil {
|
||||
logger.Error("Failed to parse DNS query: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(msg.Question) == 0 {
|
||||
logger.Debug("DNS query has no questions")
|
||||
return
|
||||
}
|
||||
|
||||
question := msg.Question[0]
|
||||
logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype])
|
||||
|
||||
// Check if we have local records for this query
|
||||
var response *dns.Msg
|
||||
if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
|
||||
response = p.checkLocalRecords(msg, question)
|
||||
}
|
||||
|
||||
// If no local records, forward to upstream
|
||||
if response == nil {
|
||||
logger.Debug("No local record for %s, forwarding upstream", question.Name)
|
||||
response = p.forwardToUpstream(msg)
|
||||
}
|
||||
|
||||
if response == nil {
|
||||
logger.Error("Failed to get DNS response for %s", question.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// Pack and send response
|
||||
responseData, err := response.Pack()
|
||||
if err != nil {
|
||||
logger.Error("Failed to pack DNS response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err = udpConn.WriteTo(responseData, clientAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// checkLocalRecords checks if we have local records for the query
|
||||
func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg {
|
||||
var recordType RecordType
|
||||
if question.Qtype == dns.TypeA {
|
||||
recordType = RecordTypeA
|
||||
} else if question.Qtype == dns.TypeAAAA {
|
||||
recordType = RecordTypeAAAA
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
ips := p.recordStore.GetRecords(question.Name, recordType)
|
||||
if len(ips) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
|
||||
|
||||
// Create response message
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add answer records
|
||||
for _, ip := range ips {
|
||||
var rr dns.RR
|
||||
if question.Qtype == dns.TypeA {
|
||||
rr = &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
A: ip.To4(),
|
||||
}
|
||||
} else { // TypeAAAA
|
||||
rr = &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300, // 5 minutes
|
||||
},
|
||||
AAAA: ip.To16(),
|
||||
}
|
||||
}
|
||||
response.Answer = append(response.Answer, rr)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// forwardToUpstream forwards a DNS query to upstream DNS servers
|
||||
func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg {
|
||||
// Try primary DNS server
|
||||
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
|
||||
if err != nil {
|
||||
@@ -226,38 +329,24 @@ func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientA
|
||||
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
|
||||
if err != nil {
|
||||
logger.Error("Both DNS servers failed: %v", err)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Send response back to client through netstack
|
||||
_, err = udpConn.WriteTo(response, clientAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send DNS response: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
// queryUpstream sends a DNS query to upstream server
|
||||
func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) {
|
||||
conn, err := net.DialTimeout("udp", server, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
|
||||
if _, err := conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
// queryUpstream sends a DNS query to upstream server using miekg/dns
|
||||
func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) {
|
||||
client := &dns.Client{
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
response := make([]byte, 4096)
|
||||
n, err := conn.Read(response)
|
||||
response, _, err := client.Exchange(query, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response[:n], nil
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// runPacketSender sends packets from netstack back to TUN
|
||||
@@ -314,3 +403,26 @@ func (p *DNSProxy) runPacketSender() {
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// AddDNSRecord adds a DNS record to the local store
|
||||
// domain should be a domain name (e.g., "example.com" or "example.com.")
|
||||
// ip should be a valid IPv4 or IPv6 address
|
||||
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error {
|
||||
return p.recordStore.AddRecord(domain, ip)
|
||||
}
|
||||
|
||||
// RemoveDNSRecord removes a DNS record from the local store
|
||||
// If ip is nil, removes all records for the domain
|
||||
func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
|
||||
p.recordStore.RemoveRecord(domain, ip)
|
||||
}
|
||||
|
||||
// GetDNSRecords returns all IP addresses for a domain and record type
|
||||
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP {
|
||||
return p.recordStore.GetRecords(domain, recordType)
|
||||
}
|
||||
|
||||
// ClearDNSRecords removes all DNS records from the local store
|
||||
func (p *DNSProxy) ClearDNSRecords() {
|
||||
p.recordStore.Clear()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user