diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index ae31ffac6..bc153479c 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -4,6 +4,8 @@ import ( "context" "errors" "net" + "net/netip" + "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" @@ -12,6 +14,7 @@ import ( ) const errResolveFailed = "failed to resolve query for domain=%s: %v" +const upstreamTimeout = 15 * time.Second type DNSForwarder struct { listenAddress string @@ -79,41 +82,72 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { domain := question.Name resp := query.SetReply(query) + var network string + switch question.Qtype { + case dns.TypeA: + network = "ip4" + case dns.TypeAAAA: + network = "ip6" + default: + // TODO: Handle other types - ips, err := net.LookupIP(domain) - if err != nil { - var dnsErr *net.DNSError - - switch { - case errors.As(err, &dnsErr): - resp.Rcode = dns.RcodeServerFailure - if dnsErr.IsNotFound { - // Pass through NXDOMAIN - resp.Rcode = dns.RcodeNameError - } - - if dnsErr.Server != "" { - log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) - } else { - log.Warnf(errResolveFailed, domain, err) - } - default: - resp.Rcode = dns.RcodeServerFailure - log.Warnf(errResolveFailed, domain, err) - } - + resp.Rcode = dns.RcodeNotImplemented if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write failure DNS response: %v", err) + log.Errorf("failed to write DNS response: %v", err) } return } + ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) + defer cancel() + ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) + if err != nil { + f.handleDNSError(w, resp, domain, err) + return + } + + f.addIPsToResponse(resp, domain, ips) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +// handleDNSError processes DNS lookup errors and sends an appropriate error response +func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) { + var dnsErr *net.DNSError + + switch { + case errors.As(err, &dnsErr): + resp.Rcode = dns.RcodeServerFailure + if dnsErr.IsNotFound { + // Pass through NXDOMAIN + resp.Rcode = dns.RcodeNameError + } + + if dnsErr.Server != "" { + log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + } else { + log.Warnf(errResolveFailed, domain, err) + } + default: + resp.Rcode = dns.RcodeServerFailure + log.Warnf(errResolveFailed, domain, err) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write failure DNS response: %v", err) + } +} + +// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records +func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) { for _, ip := range ips { var respRecord dns.RR - if ip.To4() == nil { + if ip.Is6() { log.Tracef("resolved domain=%s to IPv6=%s", domain, ip) rr := dns.AAAA{ - AAAA: ip, + AAAA: ip.AsSlice(), Hdr: dns.RR_Header{ Name: domain, Rrtype: dns.TypeAAAA, @@ -125,7 +159,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { } else { log.Tracef("resolved domain=%s to IPv4=%s", domain, ip) rr := dns.A{ - A: ip, + A: ip.AsSlice(), Hdr: dns.RR_Header{ Name: domain, Rrtype: dns.TypeA, @@ -137,10 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { } resp.Answer = append(resp.Answer, respRecord) } - - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS response: %v", err) - } } // filterDomains returns a list of normalized domains