diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 6ae7488..4734b2c 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -11,6 +11,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/device" + "github.com/miekg/dns" "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -35,11 +36,12 @@ const ( // DNSProxy implements a DNS proxy using gvisor netstack type DNSProxy struct { - stack *stack.Stack - ep *channel.Endpoint - proxyIP netip.Addr - mtu int - tunDevice tun.Device // Direct reference to underlying TUN device for responses + stack *stack.Stack + ep *channel.Endpoint + proxyIP netip.Addr + mtu int + tunDevice tun.Device // Direct reference to underlying TUN device for responses + recordStore *DNSRecordStore // Local DNS records ctx context.Context cancel context.CancelFunc @@ -56,11 +58,12 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) { ctx, cancel := context.WithCancel(context.Background()) proxy := &DNSProxy{ - proxyIP: proxyIP, - mtu: mtu, - tunDevice: tunDevice, - ctx: ctx, - cancel: cancel, + proxyIP: proxyIP, + mtu: mtu, + tunDevice: tunDevice, + recordStore: NewDNSRecordStore(), + ctx: ctx, + cancel: cancel, } // Create gvisor netstack @@ -212,12 +215,112 @@ func (p *DNSProxy) runDNSListener() { copy(query, buf[:n]) // Handle query in background - go p.forwardDNSQuery(udpConn, query, remoteAddr) + go p.handleDNSQuery(udpConn, query, remoteAddr) } } -// forwardDNSQuery forwards a DNS query to upstream DNS server -func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) { +// handleDNSQuery processes a DNS query, checking local records first, then forwarding upstream +func (p *DNSProxy) handleDNSQuery(udpConn *gonet.UDPConn, queryData []byte, clientAddr net.Addr) { + // Parse the DNS query + msg := new(dns.Msg) + if err := msg.Unpack(queryData); err != nil { + logger.Error("Failed to parse DNS query: %v", err) + return + } + + if len(msg.Question) == 0 { + logger.Debug("DNS query has no questions") + return + } + + question := msg.Question[0] + logger.Debug("DNS query for %s (type %s)", question.Name, dns.TypeToString[question.Qtype]) + + // Check if we have local records for this query + var response *dns.Msg + if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA { + response = p.checkLocalRecords(msg, question) + } + + // If no local records, forward to upstream + if response == nil { + logger.Debug("No local record for %s, forwarding upstream", question.Name) + response = p.forwardToUpstream(msg) + } + + if response == nil { + logger.Error("Failed to get DNS response for %s", question.Name) + return + } + + // Pack and send response + responseData, err := response.Pack() + if err != nil { + logger.Error("Failed to pack DNS response: %v", err) + return + } + + _, err = udpConn.WriteTo(responseData, clientAddr) + if err != nil { + logger.Error("Failed to send DNS response: %v", err) + } +} + +// checkLocalRecords checks if we have local records for the query +func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns.Msg { + var recordType RecordType + if question.Qtype == dns.TypeA { + recordType = RecordTypeA + } else if question.Qtype == dns.TypeAAAA { + recordType = RecordTypeAAAA + } else { + return nil + } + + ips := p.recordStore.GetRecords(question.Name, recordType) + if len(ips) == 0 { + return nil + } + + logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) + + // Create response message + response := new(dns.Msg) + response.SetReply(query) + response.Authoritative = true + + // Add answer records + for _, ip := range ips { + var rr dns.RR + if question.Qtype == dns.TypeA { + rr = &dns.A{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, // 5 minutes + }, + A: ip.To4(), + } + } else { // TypeAAAA + rr = &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: question.Name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, // 5 minutes + }, + AAAA: ip.To16(), + } + } + response.Answer = append(response.Answer, rr) + } + + return response +} + +// forwardToUpstream forwards a DNS query to upstream DNS servers +func (p *DNSProxy) forwardToUpstream(query *dns.Msg) *dns.Msg { // Try primary DNS server response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second) if err != nil { @@ -226,38 +329,24 @@ func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientA response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second) if err != nil { logger.Error("Both DNS servers failed: %v", err) - return + return nil } } - - // Send response back to client through netstack - _, err = udpConn.WriteTo(response, clientAddr) - if err != nil { - logger.Error("Failed to send DNS response: %v", err) - } + return response } -// queryUpstream sends a DNS query to upstream server -func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) { - conn, err := net.DialTimeout("udp", server, timeout) - if err != nil { - return nil, err - } - defer conn.Close() - - conn.SetDeadline(time.Now().Add(timeout)) - - if _, err := conn.Write(query); err != nil { - return nil, err +// queryUpstream sends a DNS query to upstream server using miekg/dns +func (p *DNSProxy) queryUpstream(server string, query *dns.Msg, timeout time.Duration) (*dns.Msg, error) { + client := &dns.Client{ + Timeout: timeout, } - response := make([]byte, 4096) - n, err := conn.Read(response) + response, _, err := client.Exchange(query, server) if err != nil { return nil, err } - return response[:n], nil + return response, nil } // runPacketSender sends packets from netstack back to TUN @@ -314,3 +403,26 @@ func (p *DNSProxy) runPacketSender() { pkt.DecRef() } } + +// AddDNSRecord adds a DNS record to the local store +// domain should be a domain name (e.g., "example.com" or "example.com.") +// ip should be a valid IPv4 or IPv6 address +func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { + return p.recordStore.AddRecord(domain, ip) +} + +// RemoveDNSRecord removes a DNS record from the local store +// If ip is nil, removes all records for the domain +func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) { + p.recordStore.RemoveRecord(domain, ip) +} + +// GetDNSRecords returns all IP addresses for a domain and record type +func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { + return p.recordStore.GetRecords(domain, recordType) +} + +// ClearDNSRecords removes all DNS records from the local store +func (p *DNSProxy) ClearDNSRecords() { + p.recordStore.Clear() +} diff --git a/dns/dns_records.go b/dns/dns_records.go new file mode 100644 index 0000000..8d57d68 --- /dev/null +++ b/dns/dns_records.go @@ -0,0 +1,166 @@ +package dns + +import ( + "net" + "sync" + + "github.com/miekg/dns" +) + +// RecordType represents the type of DNS record +type RecordType uint16 + +const ( + RecordTypeA RecordType = RecordType(dns.TypeA) + RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA) +) + +// DNSRecordStore manages local DNS records for A and AAAA queries +type DNSRecordStore struct { + mu sync.RWMutex + aRecords map[string][]net.IP // domain -> list of IPv4 addresses + aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses +} + +// NewDNSRecordStore creates a new DNS record store +func NewDNSRecordStore() *DNSRecordStore { + return &DNSRecordStore{ + aRecords: make(map[string][]net.IP), + aaaaRecords: make(map[string][]net.IP), + } +} + +// AddRecord adds a DNS record mapping (A or AAAA) +// domain should be in FQDN format (e.g., "example.com.") +// ip should be a valid IPv4 or IPv6 address +func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Ensure domain ends with a dot (FQDN format) + if len(domain) == 0 || domain[len(domain)-1] != '.' { + domain = domain + "." + } + + // Normalize domain to lowercase + domain = dns.Fqdn(domain) + + if ip.To4() != nil { + // IPv4 address + s.aRecords[domain] = append(s.aRecords[domain], ip) + } else if ip.To16() != nil { + // IPv6 address + s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) + } else { + return &net.ParseError{Type: "IP address", Text: ip.String()} + } + + return nil +} + +// RemoveRecord removes a specific DNS record mapping +// If ip is nil, removes all records for the domain +func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { + s.mu.Lock() + defer s.mu.Unlock() + + // Ensure domain ends with a dot (FQDN format) + if len(domain) == 0 || domain[len(domain)-1] != '.' { + domain = domain + "." + } + + // Normalize domain to lowercase + domain = dns.Fqdn(domain) + + if ip == nil { + // Remove all records for this domain + delete(s.aRecords, domain) + delete(s.aaaaRecords, domain) + return + } + + if ip.To4() != nil { + // Remove specific IPv4 address + if ips, ok := s.aRecords[domain]; ok { + s.aRecords[domain] = removeIP(ips, ip) + if len(s.aRecords[domain]) == 0 { + delete(s.aRecords, domain) + } + } + } else if ip.To16() != nil { + // Remove specific IPv6 address + if ips, ok := s.aaaaRecords[domain]; ok { + s.aaaaRecords[domain] = removeIP(ips, ip) + if len(s.aaaaRecords[domain]) == 0 { + delete(s.aaaaRecords, domain) + } + } + } +} + +// GetRecords returns all IP addresses for a domain and record type +func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { + s.mu.RLock() + defer s.mu.RUnlock() + + // Normalize domain to lowercase FQDN + domain = dns.Fqdn(domain) + + var records []net.IP + switch recordType { + case RecordTypeA: + if ips, ok := s.aRecords[domain]; ok { + // Return a copy to prevent external modifications + records = make([]net.IP, len(ips)) + copy(records, ips) + } + case RecordTypeAAAA: + if ips, ok := s.aaaaRecords[domain]; ok { + // Return a copy to prevent external modifications + records = make([]net.IP, len(ips)) + copy(records, ips) + } + } + + return records +} + +// HasRecord checks if a domain has any records of the specified type +func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { + s.mu.RLock() + defer s.mu.RUnlock() + + // Normalize domain to lowercase FQDN + domain = dns.Fqdn(domain) + + switch recordType { + case RecordTypeA: + _, ok := s.aRecords[domain] + return ok + case RecordTypeAAAA: + _, ok := s.aaaaRecords[domain] + return ok + } + + return false +} + +// Clear removes all records from the store +func (s *DNSRecordStore) Clear() { + s.mu.Lock() + defer s.mu.Unlock() + + s.aRecords = make(map[string][]net.IP) + s.aaaaRecords = make(map[string][]net.IP) +} + +// removeIP is a helper function to remove a specific IP from a slice +func removeIP(ips []net.IP, toRemove net.IP) []net.IP { + result := make([]net.IP, 0, len(ips)) + for _, ip := range ips { + if !ip.Equal(toRemove) { + result = append(result, ip) + } + } + return result +} diff --git a/dns/example_usage.go b/dns/example_usage.go new file mode 100644 index 0000000..0a38b97 --- /dev/null +++ b/dns/example_usage.go @@ -0,0 +1,53 @@ +package dns + +// Example usage of DNS record management (not compiled, just for reference) +/* + +import ( + "net" + "github.com/fosrl/olm/dns" +) + +func exampleUsage() { + // Assuming you have a DNSProxy instance + var proxy *dns.DNSProxy + + // Add an A record for example.com pointing to 192.168.1.100 + ip := net.ParseIP("192.168.1.100") + err := proxy.AddDNSRecord("example.com", ip) + if err != nil { + // Handle error + } + + // Add multiple A records for the same domain (round-robin) + proxy.AddDNSRecord("example.com", net.ParseIP("192.168.1.101")) + proxy.AddDNSRecord("example.com", net.ParseIP("192.168.1.102")) + + // Add an AAAA record (IPv6) + ipv6 := net.ParseIP("2001:db8::1") + proxy.AddDNSRecord("example.com", ipv6) + + // Query records + aRecords := proxy.GetDNSRecords("example.com", dns.RecordTypeA) + // Returns: [192.168.1.100, 192.168.1.101, 192.168.1.102] + + aaaaRecords := proxy.GetDNSRecords("example.com", dns.RecordTypeAAAA) + // Returns: [2001:db8::1] + + // Remove a specific record + proxy.RemoveDNSRecord("example.com", net.ParseIP("192.168.1.101")) + + // Remove all records for a domain + proxy.RemoveDNSRecord("example.com", nil) + + // Clear all DNS records + proxy.ClearDNSRecords() +} + +// How it works: +// 1. When a DNS query arrives, the proxy first checks its local record store +// 2. If a matching A or AAAA record exists locally, it returns that immediately +// 3. If no local record exists, it forwards the query to upstream DNS (8.8.8.8 or 8.8.4.4) +// 4. All other DNS record types (MX, CNAME, TXT, etc.) are always forwarded upstream + +*/ diff --git a/go.mod b/go.mod index e32b1d2..a5fc99c 100644 --- a/go.mod +++ b/go.mod @@ -16,11 +16,15 @@ require ( require ( github.com/google/btree v1.1.3 // indirect + github.com/miekg/dns v1.1.68 // indirect github.com/vishvananda/netns v0.0.5 // indirect golang.org/x/crypto v0.44.0 // indirect golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect + golang.org/x/mod v0.30.0 // indirect golang.org/x/net v0.47.0 // indirect + golang.org/x/sync v0.18.0 // indirect golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.39.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) diff --git a/go.sum b/go.sum index 46054fa..c439800 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA= +github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= @@ -14,14 +16,20 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+ZbWg+4sHnLp52d5yiIPUxMBSt4X9A= diff --git a/olm/example_extension.go.template b/olm/example_extension.go.template deleted file mode 100644 index 44604f7..0000000 --- a/olm/example_extension.go.template +++ /dev/null @@ -1,111 +0,0 @@ -package olm - -// This file demonstrates how to add additional virtual services using the FilteredDevice infrastructure -// Copy and modify this template to add new services - -import ( - "context" - "net/netip" - "sync" - - "github.com/fosrl/newt/logger" - "golang.zx2c4.com/wireguard/tun" -) - -// Example: Simple echo server on 10.30.30.50:7777 - -const ( - EchoProxyIP = "10.30.30.50" - EchoProxyPort = 7777 -) - -// EchoProxy implements a simple echo server -type EchoProxy struct { - proxyIP netip.Addr - tunDevice tun.Device - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup -} - -// NewEchoProxy creates a new echo proxy instance -func NewEchoProxy(tunDevice tun.Device) (*EchoProxy, error) { - proxyIP := netip.MustParseAddr(EchoProxyIP) - ctx, cancel := context.WithCancel(context.Background()) - - return &EchoProxy{ - proxyIP: proxyIP, - tunDevice: tunDevice, - ctx: ctx, - cancel: cancel, - }, nil -} - -// Start registers the proxy with the filter -func (e *EchoProxy) Start(filter *FilteredDevice) error { - filter.AddRule(e.proxyIP, e.handlePacket) - logger.Info("Echo proxy started on %s:%d", EchoProxyIP, EchoProxyPort) - return nil -} - -// Stop unregisters the proxy -func (e *EchoProxy) Stop(filter *FilteredDevice) { - if filter != nil { - filter.RemoveRule(e.proxyIP) - } - e.cancel() - e.wg.Wait() - logger.Info("Echo proxy stopped") -} - -// handlePacket processes packets destined for the echo server -func (e *EchoProxy) handlePacket(packet []byte) bool { - // Quick validation - if len(packet) < 20 { - return false - } - - // Check protocol (UDP) - proto, ok := GetProtocol(packet) - if !ok || proto != 17 { - return false - } - - // Check port - port, ok := GetDestPort(packet) - if !ok || port != EchoProxyPort { - return false - } - - // For a real implementation, you would: - // 1. Parse the UDP packet - // 2. Extract the payload - // 3. Create a response packet with swapped src/dest - // 4. Write response back to TUN device - - logger.Debug("Echo proxy received packet (would echo back)") - - // Return true to drop packet from normal WireGuard path - return true -} - -// Example integration in olm.go: -// -// var echoProxy *EchoProxy -// -// // During tunnel setup (after creating filteredDev): -// echoProxy, err = NewEchoProxy(tdev) -// if err != nil { -// logger.Error("Failed to create echo proxy: %v", err) -// return -// } -// if err := echoProxy.Start(filteredDev); err != nil { -// logger.Error("Failed to start echo proxy: %v", err) -// return -// } -// -// // During tunnel teardown: -// if echoProxy != nil { -// echoProxy.Stop(filteredDev) -// echoProxy = nil -// } diff --git a/olm/olm.go b/olm/olm.go index bc6f828..ac28a7b 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -435,11 +435,13 @@ func StartTunnel(config TunnelConfig) { dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU) if err != nil { logger.Error("Failed to create DNS proxy: %v", err) - return } if err := dnsProxy.Start(middleDev); err != nil { logger.Error("Failed to start DNS proxy: %v", err) - return + } + ip := net.ParseIP("192.168.1.100") + if dnsProxy.AddDNSRecord("example.com", ip); err != nil { + logger.Error("Failed to add DNS record: %v", err) } // fileUAPI, err := func() (*os.File, error) {