mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
[client] Fix dns forwarder handling of requested record types (#3615)
This commit is contained in:
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -12,6 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress string
|
||||||
@@ -79,41 +82,72 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
domain := question.Name
|
domain := question.Name
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
|
var network string
|
||||||
|
switch question.Qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
network = "ip4"
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
network = "ip6"
|
||||||
|
default:
|
||||||
|
// TODO: Handle other types
|
||||||
|
|
||||||
ips, err := net.LookupIP(domain)
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err != nil {
|
|
||||||
var dnsErr *net.DNSError
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case errors.As(err, &dnsErr):
|
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
|
||||||
if dnsErr.IsNotFound {
|
|
||||||
// Pass through NXDOMAIN
|
|
||||||
resp.Rcode = dns.RcodeNameError
|
|
||||||
}
|
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
|
||||||
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
|
||||||
} else {
|
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write failure DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
|
defer cancel()
|
||||||
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
||||||
|
if err != nil {
|
||||||
|
f.handleDNSError(w, resp, domain, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||||
|
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.As(err, &dnsErr):
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
if dnsErr.IsNotFound {
|
||||||
|
// Pass through NXDOMAIN
|
||||||
|
resp.Rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
||||||
|
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var respRecord dns.RR
|
var respRecord dns.RR
|
||||||
if ip.To4() == nil {
|
if ip.Is6() {
|
||||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||||
rr := dns.AAAA{
|
rr := dns.AAAA{
|
||||||
AAAA: ip,
|
AAAA: ip.AsSlice(),
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: domain,
|
Name: domain,
|
||||||
Rrtype: dns.TypeAAAA,
|
Rrtype: dns.TypeAAAA,
|
||||||
@@ -125,7 +159,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
} else {
|
} else {
|
||||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||||
rr := dns.A{
|
rr := dns.A{
|
||||||
A: ip,
|
A: ip.AsSlice(),
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: domain,
|
Name: domain,
|
||||||
Rrtype: dns.TypeA,
|
Rrtype: dns.TypeA,
|
||||||
@@ -137,10 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
resp.Answer = append(resp.Answer, respRecord)
|
resp.Answer = append(resp.Answer, respRecord)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterDomains returns a list of normalized domains
|
// filterDomains returns a list of normalized domains
|
||||||
|
|||||||
Reference in New Issue
Block a user