mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
[client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046)
This commit is contained in:
@@ -3,11 +3,15 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -43,7 +47,23 @@ type HandlerChain struct {
|
|||||||
type ResponseWriterChain struct {
|
type ResponseWriterChain struct {
|
||||||
dns.ResponseWriter
|
dns.ResponseWriter
|
||||||
origPattern string
|
origPattern string
|
||||||
|
requestID string
|
||||||
shouldContinue bool
|
shouldContinue bool
|
||||||
|
response *dns.Msg
|
||||||
|
meta map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestID returns the request ID for tracing
|
||||||
|
func (w *ResponseWriterChain) RequestID() string {
|
||||||
|
return w.requestID
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMeta sets a metadata key-value pair for logging
|
||||||
|
func (w *ResponseWriterChain) SetMeta(key, value string) {
|
||||||
|
if w.meta == nil {
|
||||||
|
w.meta = make(map[string]string)
|
||||||
|
}
|
||||||
|
w.meta[key] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
||||||
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
|
|||||||
w.shouldContinue = true
|
w.shouldContinue = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
w.response = m
|
||||||
return w.ResponseWriter.WriteMsg(m)
|
return w.ResponseWriter.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
|
|
||||||
pos := c.findHandlerPosition(entry)
|
pos := c.findHandlerPosition(entry)
|
||||||
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
|
||||||
|
|
||||||
|
c.logHandlers()
|
||||||
}
|
}
|
||||||
|
|
||||||
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
// findHandlerPosition determines where to insert a new handler based on priority and specificity
|
||||||
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
|||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
|
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
|
c.logHandlers()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// logHandlers logs the current handler chain state. Caller must hold the lock.
|
||||||
|
func (c *HandlerChain) logHandlers() {
|
||||||
|
if !log.IsLevelEnabled(log.TraceLevel) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
|
||||||
|
for _, h := range c.handlers {
|
||||||
|
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
|
||||||
|
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
|
||||||
|
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
|
||||||
|
" priority=" + strconv.Itoa(h.Priority) + "\n")
|
||||||
|
}
|
||||||
|
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
qname := strings.ToLower(r.Question[0].Name)
|
startTime := time.Now()
|
||||||
|
requestID := resutil.GenerateRequestID()
|
||||||
|
logger := log.WithFields(log.Fields{
|
||||||
|
"request_id": requestID,
|
||||||
|
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||||
|
})
|
||||||
|
|
||||||
|
question := r.Question[0]
|
||||||
|
qname := strings.ToLower(question.Name)
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
handlers := slices.Clone(c.handlers)
|
handlers := slices.Clone(c.handlers)
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
|
|
||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
|
||||||
for _, h := range handlers {
|
|
||||||
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
|
||||||
}
|
|
||||||
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
matched := c.isHandlerMatch(qname, entry)
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
|
continue
|
||||||
if matched {
|
|
||||||
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
|
||||||
ResponseWriter: w,
|
|
||||||
origPattern: entry.OrigPattern,
|
|
||||||
}
|
|
||||||
entry.Handler.ServeDNS(chainWriter, r)
|
|
||||||
|
|
||||||
// If handler wants to continue, try next handler
|
|
||||||
if chainWriter.shouldContinue {
|
|
||||||
// Only log continue for non-management cache handlers to reduce noise
|
|
||||||
if entry.Priority != PriorityMgmtCache {
|
|
||||||
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handlerName := entry.OrigPattern
|
||||||
|
if s, ok := entry.Handler.(interface{ String() string }); ok {
|
||||||
|
handlerName = s.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
|
||||||
|
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
|
chainWriter := &ResponseWriterChain{
|
||||||
|
ResponseWriter: w,
|
||||||
|
origPattern: entry.OrigPattern,
|
||||||
|
requestID: requestID,
|
||||||
|
}
|
||||||
|
entry.Handler.ServeDNS(chainWriter, r)
|
||||||
|
|
||||||
|
// If handler wants to continue, try next handler
|
||||||
|
if chainWriter.shouldContinue {
|
||||||
|
if entry.Priority != PriorityMgmtCache {
|
||||||
|
logger.Tracef("handler requested continue for domain=%s", qname)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.logResponse(logger, chainWriter, qname, startTime)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// No handler matched or all handlers passed
|
// No handler matched or all handlers passed
|
||||||
log.Tracef("no handler found for domain=%s", qname)
|
logger.Tracef("no handler found for domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetRcode(r, dns.RcodeRefused)
|
resp.SetRcode(r, dns.RcodeRefused)
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
|
||||||
|
if cw.response == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var meta string
|
||||||
|
for k, v := range cw.meta {
|
||||||
|
meta += " " + k + "=" + v
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
|
||||||
|
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
|
||||||
|
meta, time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
switch {
|
switch {
|
||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
|
|||||||
@@ -1,30 +1,50 @@
|
|||||||
package local
|
package local
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const externalResolutionTimeout = 4 * time.Second
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
records map[dns.Question][]dns.RR
|
records map[dns.Question][]dns.RR
|
||||||
domains map[domain.Domain]struct{}
|
domains map[domain.Domain]struct{}
|
||||||
|
zones []domain.Domain
|
||||||
|
resolver resolver
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question][]dns.RR),
|
records: make(map[dns.Question][]dns.RR),
|
||||||
domains: make(map[domain.Domain]struct{}),
|
domains: make(map[domain.Domain]struct{}),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,7 +57,18 @@ func (d *Resolver) String() string {
|
|||||||
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Resolver) Stop() {}
|
func (d *Resolver) Stop() {
|
||||||
|
if d.cancel != nil {
|
||||||
|
d.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
maps.Clear(d.records)
|
||||||
|
maps.Clear(d.domains)
|
||||||
|
d.zones = nil
|
||||||
|
}
|
||||||
|
|
||||||
// ID returns the unique handler ID
|
// ID returns the unique handler ID
|
||||||
func (d *Resolver) ID() types.HandlerID {
|
func (d *Resolver) ID() types.HandlerID {
|
||||||
@@ -48,38 +79,47 @@ func (d *Resolver) ProbeAvailability() {}
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||||
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
log.Debugf("received local resolver request with no question")
|
logger.Debug("received local resolver request with no question")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
question := r.Question[0]
|
question := r.Question[0]
|
||||||
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
||||||
|
|
||||||
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
|
|
||||||
|
|
||||||
replyMessage := &dns.Msg{}
|
replyMessage := &dns.Msg{}
|
||||||
replyMessage.SetReply(r)
|
replyMessage.SetReply(r)
|
||||||
replyMessage.RecursionAvailable = true
|
replyMessage.RecursionAvailable = true
|
||||||
|
|
||||||
// lookup all records matching the question
|
result := d.lookupRecords(logger, question)
|
||||||
records := d.lookupRecords(question)
|
replyMessage.Authoritative = !result.hasExternalData
|
||||||
if len(records) > 0 {
|
replyMessage.Answer = result.records
|
||||||
replyMessage.Rcode = dns.RcodeSuccess
|
replyMessage.Rcode = d.determineRcode(question, result)
|
||||||
replyMessage.Answer = append(replyMessage.Answer, records...)
|
|
||||||
} else {
|
|
||||||
// Check if we have any records for this domain name with different types
|
|
||||||
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
|
||||||
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
|
|
||||||
} else {
|
|
||||||
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(replyMessage); err != nil {
|
if err := w.WriteMsg(replyMessage); err != nil {
|
||||||
log.Warnf("failed to write the local resolver response: %v", err)
|
logger.Warnf("failed to write the local resolver response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// determineRcode returns the appropriate DNS response code.
|
||||||
|
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
|
||||||
|
// even if CNAME records are included in the answer.
|
||||||
|
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
|
||||||
|
// Use the rcode from lookup - this properly handles CNAME chains where
|
||||||
|
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
|
||||||
|
if result.rcode != 0 {
|
||||||
|
return result.rcode
|
||||||
|
}
|
||||||
|
|
||||||
|
// No records found, but domain exists with different record types (NODATA)
|
||||||
|
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
|
||||||
|
return dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
return dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
||||||
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
@@ -89,8 +129,33 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
|
|||||||
return exists
|
return exists
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isInManagedZone checks if the given name falls within any of our managed zones.
|
||||||
|
// This is used to avoid unnecessary external resolution for CNAME targets that
|
||||||
|
// are within zones we manage - if we don't have a record for it, it doesn't exist.
|
||||||
|
// Caller must NOT hold the lock.
|
||||||
|
func (d *Resolver) isInManagedZone(name string) bool {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
name = dns.Fqdn(name)
|
||||||
|
for _, zone := range d.zones {
|
||||||
|
zoneStr := dns.Fqdn(zone.PunycodeString())
|
||||||
|
if strings.EqualFold(name, zoneStr) || strings.HasSuffix(strings.ToLower(name), strings.ToLower("."+zoneStr)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupResult contains the result of a DNS lookup operation.
|
||||||
|
type lookupResult struct {
|
||||||
|
records []dns.RR
|
||||||
|
rcode int
|
||||||
|
hasExternalData bool
|
||||||
|
}
|
||||||
|
|
||||||
// lookupRecords fetches *all* DNS records matching the first question in r.
|
// lookupRecords fetches *all* DNS records matching the first question in r.
|
||||||
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
||||||
d.mu.RLock()
|
d.mu.RLock()
|
||||||
records, found := d.records[question]
|
records, found := d.records[question]
|
||||||
|
|
||||||
@@ -98,10 +163,14 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
|||||||
d.mu.RUnlock()
|
d.mu.RUnlock()
|
||||||
// alternatively check if we have a cname
|
// alternatively check if we have a cname
|
||||||
if question.Qtype != dns.TypeCNAME {
|
if question.Qtype != dns.TypeCNAME {
|
||||||
question.Qtype = dns.TypeCNAME
|
cnameQuestion := dns.Question{
|
||||||
return d.lookupRecords(question)
|
Name: question.Name,
|
||||||
|
Qtype: dns.TypeCNAME,
|
||||||
|
Qclass: question.Qclass,
|
||||||
|
}
|
||||||
|
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
|
||||||
}
|
}
|
||||||
return nil
|
return lookupResult{rcode: dns.RcodeNameError}
|
||||||
}
|
}
|
||||||
|
|
||||||
recordsCopy := slices.Clone(records)
|
recordsCopy := slices.Clone(records)
|
||||||
@@ -119,16 +188,172 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
|
|||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return recordsCopy
|
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
|
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
||||||
|
// the final resolved record of the requested type. This is required for musl libc
|
||||||
|
// compatibility, which expects the full answer chain rather than just the CNAME.
|
||||||
|
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
|
||||||
|
const maxDepth = 8
|
||||||
|
var chain []dns.RR
|
||||||
|
|
||||||
|
for range maxDepth {
|
||||||
|
cnameRecords := d.getRecords(cnameQuestion)
|
||||||
|
if len(cnameRecords) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
chain = append(chain, cnameRecords...)
|
||||||
|
|
||||||
|
cname, ok := cnameRecords[0].(*dns.CNAME)
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
targetName := strings.ToLower(cname.Target)
|
||||||
|
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
|
||||||
|
|
||||||
|
// keep following chain
|
||||||
|
if targetResult.rcode == -1 {
|
||||||
|
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.buildChainResult(chain, targetResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(chain) > 0 {
|
||||||
|
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
|
||||||
|
}
|
||||||
|
return lookupResult{rcode: dns.RcodeSuccess}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildChainResult combines CNAME chain records with the target resolution result.
|
||||||
|
// Per RFC 6604, the final rcode is propagated through the chain.
|
||||||
|
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
|
||||||
|
records := chain
|
||||||
|
if len(target.records) > 0 {
|
||||||
|
records = append(records, target.records...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
|
||||||
|
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
|
||||||
|
return lookupResult{
|
||||||
|
records: records,
|
||||||
|
rcode: dns.RcodeServerFailure,
|
||||||
|
hasExternalData: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lookupResult{
|
||||||
|
records: records,
|
||||||
|
rcode: target.rcode,
|
||||||
|
hasExternalData: target.hasExternalData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveCNAMETarget attempts to resolve a CNAME target name.
|
||||||
|
// Returns rcode=-1 to signal "keep following the chain".
|
||||||
|
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
|
||||||
|
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
|
||||||
|
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
||||||
|
}
|
||||||
|
|
||||||
|
// another CNAME, keep following
|
||||||
|
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
|
||||||
|
return lookupResult{rcode: -1}
|
||||||
|
}
|
||||||
|
|
||||||
|
// domain exists locally but not this record type (NODATA)
|
||||||
|
if d.hasRecordsForDomain(domain.Domain(targetName)) {
|
||||||
|
return lookupResult{rcode: dns.RcodeSuccess}
|
||||||
|
}
|
||||||
|
|
||||||
|
// in our zone but doesn't exist (NXDOMAIN)
|
||||||
|
if d.isInManagedZone(targetName) {
|
||||||
|
return lookupResult{rcode: dns.RcodeNameError}
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.resolveExternal(logger, targetName, targetType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
return d.records[q]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Resolver) hasRecord(q dns.Question) bool {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
_, ok := d.records[q]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveExternal resolves a domain name using the system resolver.
|
||||||
|
// This is used to resolve CNAME targets that point outside our local zone,
|
||||||
|
// which is required for musl libc compatibility (musl expects complete answers).
|
||||||
|
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
|
||||||
|
network := resutil.NetworkForQtype(qtype)
|
||||||
|
if network == "" {
|
||||||
|
return lookupResult{rcode: dns.RcodeNotImplemented}
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver := d.resolver
|
||||||
|
if resolver == nil {
|
||||||
|
resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
|
||||||
|
if result.Err != nil {
|
||||||
|
d.logDNSError(logger, name, qtype, result.Err)
|
||||||
|
return lookupResult{rcode: result.Rcode, hasExternalData: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lookupResult{
|
||||||
|
records: resutil.IPsToRRs(name, result.IPs, 60),
|
||||||
|
rcode: dns.RcodeSuccess,
|
||||||
|
hasExternalData: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// logDNSError logs DNS resolution errors for debugging.
|
||||||
|
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
|
||||||
|
qtypeName := dns.TypeToString[qtype]
|
||||||
|
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if !errors.As(err, &dnsErr) {
|
||||||
|
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.IsNotFound {
|
||||||
|
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
|
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update updates the resolver with new records and zone information.
|
||||||
|
// The zones parameter specifies which DNS zones this resolver manages.
|
||||||
|
func (d *Resolver) Update(update []nbdns.SimpleRecord, zones []domain.Domain) {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
maps.Clear(d.records)
|
maps.Clear(d.records)
|
||||||
maps.Clear(d.domains)
|
maps.Clear(d.domains)
|
||||||
|
|
||||||
|
d.zones = zones
|
||||||
|
|
||||||
for _, rec := range update {
|
for _, rec := range update {
|
||||||
if err := d.registerRecord(rec); err != nil {
|
if err := d.registerRecord(rec); err != nil {
|
||||||
log.Warnf("failed to register the record (%s): %v", rec, err)
|
log.Warnf("failed to register the record (%s): %v", rec, err)
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
package local
|
package local
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -10,8 +16,21 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// mockResolver implements resolver for testing
|
||||||
|
type mockResolver struct {
|
||||||
|
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if m.lookupFunc != nil {
|
||||||
|
return m.lookupFunc(ctx, network, host)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||||
recordA := nbdns.SimpleRecord{
|
recordA := nbdns.SimpleRecord{
|
||||||
Name: "peera.netbird.cloud.",
|
Name: "peera.netbird.cloud.",
|
||||||
@@ -110,7 +129,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
|||||||
update2 := []nbdns.SimpleRecord{record2}
|
update2 := []nbdns.SimpleRecord{record2}
|
||||||
|
|
||||||
// Apply first update
|
// Apply first update
|
||||||
resolver.Update(update1)
|
resolver.Update(update1, nil)
|
||||||
|
|
||||||
// Verify first update
|
// Verify first update
|
||||||
resolver.mu.RLock()
|
resolver.mu.RLock()
|
||||||
@@ -122,7 +141,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
|
|||||||
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
|
||||||
|
|
||||||
// Apply second update
|
// Apply second update
|
||||||
resolver.Update(update2)
|
resolver.Update(update2, nil)
|
||||||
|
|
||||||
// Verify second update
|
// Verify second update
|
||||||
resolver.mu.RLock()
|
resolver.mu.RLock()
|
||||||
@@ -154,7 +173,7 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
|
|||||||
update := []nbdns.SimpleRecord{record1, record2}
|
update := []nbdns.SimpleRecord{record1, record2}
|
||||||
|
|
||||||
// Apply update with both records
|
// Apply update with both records
|
||||||
resolver.Update(update)
|
resolver.Update(update, nil)
|
||||||
|
|
||||||
// Create question that matches both records
|
// Create question that matches both records
|
||||||
question := dns.Question{
|
question := dns.Question{
|
||||||
@@ -198,7 +217,7 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
|
|||||||
update := []nbdns.SimpleRecord{record1, record2, record3}
|
update := []nbdns.SimpleRecord{record1, record2, record3}
|
||||||
|
|
||||||
// Apply update with all three records
|
// Apply update with all three records
|
||||||
resolver.Update(update)
|
resolver.Update(update, nil)
|
||||||
|
|
||||||
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
msg := new(dns.Msg).SetQuestion(recordName, recordType)
|
||||||
|
|
||||||
@@ -264,7 +283,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update resolver with the records
|
// Update resolver with the records
|
||||||
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
|
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}, nil)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -379,7 +398,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update resolver with both records
|
// Update resolver with both records
|
||||||
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
|
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}, nil)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -476,6 +495,20 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
|
|||||||
// with 0 records instead of NXDOMAIN
|
// with 0 records instead of NXDOMAIN
|
||||||
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
||||||
resolver := NewResolver()
|
resolver := NewResolver()
|
||||||
|
// Mock external resolver for CNAME target resolution
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "target.example.com." {
|
||||||
|
if network == "ip4" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||||
|
}
|
||||||
|
if network == "ip6" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
recordA := nbdns.SimpleRecord{
|
recordA := nbdns.SimpleRecord{
|
||||||
Name: "example.netbird.cloud.",
|
Name: "example.netbird.cloud.",
|
||||||
@@ -493,7 +526,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
|||||||
RData: "target.example.com.",
|
RData: "target.example.com.",
|
||||||
}
|
}
|
||||||
|
|
||||||
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
|
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}, nil)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -582,3 +615,555 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
|
||||||
|
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
|
||||||
|
t.Run("simple internal CNAME chain", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
|
||||||
|
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 2)
|
||||||
|
|
||||||
|
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "target.example.com.", cname.Target)
|
||||||
|
|
||||||
|
a, ok := resp.Answer[1].(*dns.A)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "192.168.1.1", a.A.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multi-hop CNAME chain", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
|
||||||
|
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
|
||||||
|
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 3)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
assert.True(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
|
||||||
|
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
|
||||||
|
t.Run("chain at max depth resolves", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
var records []nbdns.SimpleRecord
|
||||||
|
// Create chain of 7 CNAMEs (under max of 8)
|
||||||
|
for i := 1; i <= 7; i++ {
|
||||||
|
records = append(records, nbdns.SimpleRecord{
|
||||||
|
Name: fmt.Sprintf("hop%d.test.", i),
|
||||||
|
Type: int(dns.TypeCNAME),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: fmt.Sprintf("hop%d.test.", i+1),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
records = append(records, nbdns.SimpleRecord{
|
||||||
|
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||||
|
})
|
||||||
|
|
||||||
|
resolver.Update(records, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 8)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chain exceeding max depth stops", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
var records []nbdns.SimpleRecord
|
||||||
|
// Create chain of 10 CNAMEs (exceeds max of 8)
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
records = append(records, nbdns.SimpleRecord{
|
||||||
|
Name: fmt.Sprintf("deep%d.test.", i),
|
||||||
|
Type: int(dns.TypeCNAME),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: 300,
|
||||||
|
RData: fmt.Sprintf("deep%d.test.", i+1),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
records = append(records, nbdns.SimpleRecord{
|
||||||
|
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
|
||||||
|
})
|
||||||
|
|
||||||
|
resolver.Update(records, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
// Should NOT have the final A record (chain too deep)
|
||||||
|
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
|
||||||
|
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.LessOrEqual(t, len(resp.Answer), 8)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
|
||||||
|
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
|
||||||
|
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "external.example.com." && network == "ip4" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
|
||||||
|
|
||||||
|
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "external.example.com.", cname.Target)
|
||||||
|
|
||||||
|
a, ok := resp.Answer[1].(*dns.A)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "93.184.216.34", a.A.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "external.example.com." && network == "ip6" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
|
||||||
|
|
||||||
|
cname, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "external.example.com.", cname.Target)
|
||||||
|
|
||||||
|
aaaa, ok := resp.Answer[1].(*dns.AAAA)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("concurrent external resolution", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "external.example.com." && network == "ip4" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
results := make([]*dns.Msg, 10)
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(idx int) {
|
||||||
|
defer wg.Done()
|
||||||
|
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
results[idx] = resp
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
for i, resp := range results {
|
||||||
|
require.NotNil(t, resp, "Response %d should not be nil", i)
|
||||||
|
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
|
||||||
|
func TestLocalResolver_ZoneManagement(t *testing.T) {
|
||||||
|
t.Run("Update sets zones correctly", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
zones := []domain.Domain{"example.com", "test.local"}
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
}, zones)
|
||||||
|
|
||||||
|
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||||
|
assert.True(t, resolver.isInManagedZone("other.example.com."))
|
||||||
|
assert.True(t, resolver.isInManagedZone("sub.test.local."))
|
||||||
|
assert.False(t, resolver.isInManagedZone("external.com."))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update(nil, []domain.Domain{"Example.COM"})
|
||||||
|
|
||||||
|
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||||
|
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Update clears zones", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update(nil, []domain.Domain{"example.com"})
|
||||||
|
assert.True(t, resolver.isInManagedZone("host.example.com."))
|
||||||
|
|
||||||
|
resolver.Update(nil, nil)
|
||||||
|
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
|
||||||
|
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
|
||||||
|
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
|
||||||
|
}, []domain.Domain{"myzone.test"})
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
|
||||||
|
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "external.other.com." && network == "ip4" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
|
||||||
|
}, []domain.Domain{"myzone.test"})
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
|
||||||
|
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
|
||||||
|
}, []domain.Domain{"myzone.test"})
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||||
|
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
|
||||||
|
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
assert.True(t, ok, "Answer should be CNAME record")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "external.example.com." {
|
||||||
|
if network == "ip6" {
|
||||||
|
// No AAAA records
|
||||||
|
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||||
|
}
|
||||||
|
if network == "ip4" {
|
||||||
|
// But A records exist - domain exists
|
||||||
|
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
|
||||||
|
require.Len(t, resp.Answer, 1, "Should have only CNAME")
|
||||||
|
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
assert.True(t, ok, "Answer should be CNAME record")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Table-driven test for all external resolution outcomes
|
||||||
|
externalCases := []struct {
|
||||||
|
name string
|
||||||
|
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
|
||||||
|
expectedRcode int
|
||||||
|
expectedAnswer int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "external NXDOMAIN (both A and AAAA not found)",
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||||
|
},
|
||||||
|
expectedRcode: dns.RcodeNameError,
|
||||||
|
expectedAnswer: 1, // CNAME only
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "external SERVFAIL (temporary error)",
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
return nil, &net.DNSError{IsTemporary: true, Name: host}
|
||||||
|
},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectedAnswer: 1, // CNAME only
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "external SERVFAIL (timeout)",
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
return nil, &net.DNSError{IsTimeout: true, Name: host}
|
||||||
|
},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectedAnswer: 1, // CNAME only
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "external SERVFAIL (generic error)",
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
return nil, fmt.Errorf("connection refused")
|
||||||
|
},
|
||||||
|
expectedRcode: dns.RcodeServerFailure,
|
||||||
|
expectedAnswer: 1, // CNAME only
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "external success with IPs",
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if network == "ip4" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||||
|
}
|
||||||
|
return nil, &net.DNSError{IsNotFound: true, Name: host}
|
||||||
|
},
|
||||||
|
expectedRcode: dns.RcodeSuccess,
|
||||||
|
expectedAnswer: 2, // CNAME + A
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range externalCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
|
||||||
|
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
|
||||||
|
if tc.expectedAnswer > 0 {
|
||||||
|
_, ok := resp.Answer[0].(*dns.CNAME)
|
||||||
|
assert.True(t, ok, "first answer should be CNAME")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
|
||||||
|
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
|
||||||
|
t.Run("direct record lookup is authoritative", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
}, []domain.Domain{"example.com"})
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.True(t, resp.Authoritative)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("external resolution is not authoritative", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
if host == "external.example.com." && network == "ip4" {
|
||||||
|
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 2)
|
||||||
|
assert.False(t, resp.Authoritative)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLocalResolver_Stop tests cleanup on Stop
|
||||||
|
func TestLocalResolver_Stop(t *testing.T) {
|
||||||
|
t.Run("Stop clears all state", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
}, []domain.Domain{"example.com"})
|
||||||
|
|
||||||
|
resolver.Stop()
|
||||||
|
|
||||||
|
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
|
||||||
|
var resp *dns.Msg
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
|
||||||
|
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
assert.Len(t, resp.Answer, 0)
|
||||||
|
assert.False(t, resolver.isInManagedZone("host.example.com."))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
|
||||||
|
}, []domain.Domain{"example.com"})
|
||||||
|
|
||||||
|
resolver.Stop()
|
||||||
|
resolver.Stop()
|
||||||
|
resolver.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
|
||||||
|
resolver := NewResolver()
|
||||||
|
|
||||||
|
lookupStarted := make(chan struct{})
|
||||||
|
lookupCtxCanceled := make(chan struct{})
|
||||||
|
|
||||||
|
resolver.resolver = &mockResolver{
|
||||||
|
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
|
||||||
|
close(lookupStarted)
|
||||||
|
<-ctx.Done()
|
||||||
|
close(lookupCtxCanceled)
|
||||||
|
return nil, ctx.Err()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver.Update([]nbdns.SimpleRecord{
|
||||||
|
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
|
||||||
|
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-lookupStarted
|
||||||
|
resolver.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-lookupCtxCanceled:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("external lookup context was not canceled")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("ServeDNS did not return after Stop")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
197
client/internal/dns/resutil/resolve.go
Normal file
197
client/internal/dns/resutil/resolve.go
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
// Package resutil provides shared DNS resolution utilities
|
||||||
|
package resutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||||
|
func GenerateRequestID() string {
|
||||||
|
bytes := make([]byte, 4)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
log.Errorf("generate request ID: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPsToRRs converts a slice of IP addresses to DNS resource records.
|
||||||
|
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
|
||||||
|
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
|
||||||
|
var result []dns.RR
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
if ip.Is6() {
|
||||||
|
result = append(result, &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: name,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: ttl,
|
||||||
|
},
|
||||||
|
AAAA: ip.AsSlice(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
result = append(result, &dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: name,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: ttl,
|
||||||
|
},
|
||||||
|
A: ip.AsSlice(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
|
||||||
|
// Returns empty string for unsupported types.
|
||||||
|
func NetworkForQtype(qtype uint16) string {
|
||||||
|
switch qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
return "ip4"
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
return "ip6"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// chainedWriter is implemented by ResponseWriters that carry request metadata
|
||||||
|
type chainedWriter interface {
|
||||||
|
RequestID() string
|
||||||
|
SetMeta(key, value string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequestID extracts a request ID from the ResponseWriter if available,
|
||||||
|
// otherwise generates a new one.
|
||||||
|
func GetRequestID(w dns.ResponseWriter) string {
|
||||||
|
if cw, ok := w.(chainedWriter); ok {
|
||||||
|
if id := cw.RequestID(); id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return GenerateRequestID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMeta sets metadata on the ResponseWriter if it supports it.
|
||||||
|
func SetMeta(w dns.ResponseWriter, key, value string) {
|
||||||
|
if cw, ok := w.(chainedWriter); ok {
|
||||||
|
cw.SetMeta(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupResult contains the result of an external DNS lookup
|
||||||
|
type LookupResult struct {
|
||||||
|
IPs []netip.Addr
|
||||||
|
Rcode int
|
||||||
|
Err error // Original error for caller's logging needs
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIP performs a DNS lookup and determines the appropriate rcode.
|
||||||
|
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
|
||||||
|
ips, err := r.LookupNetIP(ctx, network, host)
|
||||||
|
if err != nil {
|
||||||
|
return LookupResult{
|
||||||
|
Rcode: getRcodeForError(ctx, r, host, qtype, err),
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
||||||
|
for i, ip := range ips {
|
||||||
|
ips[i] = ip.Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
|
return LookupResult{
|
||||||
|
IPs: ips,
|
||||||
|
Rcode: dns.RcodeSuccess,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if !errors.As(err, &dnsErr) {
|
||||||
|
return dns.RcodeServerFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsErr.IsNotFound {
|
||||||
|
return getRcodeForNotFound(ctx, r, host, qtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dns.RcodeServerFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
|
||||||
|
// (domain exists but no records of requested type) by checking the opposite record type.
|
||||||
|
//
|
||||||
|
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
|
||||||
|
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
|
||||||
|
// are not queried by musl and don't need this handling.
|
||||||
|
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
|
||||||
|
// Try querying for a different record type to see if the domain exists
|
||||||
|
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||||
|
// This helps distinguish between NXDOMAIN and NODATA.
|
||||||
|
var alternativeNetwork string
|
||||||
|
switch originalQtype {
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
alternativeNetwork = "ip4"
|
||||||
|
case dns.TypeA:
|
||||||
|
alternativeNetwork = "ip6"
|
||||||
|
default:
|
||||||
|
return dns.RcodeNameError
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||||
|
var dnsErr *net.DNSError
|
||||||
|
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||||
|
// Alternative query also returned not found - domain truly doesn't exist
|
||||||
|
return dns.RcodeNameError
|
||||||
|
}
|
||||||
|
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||||
|
return dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
// Alternative query succeeded - domain exists but has no records of this type
|
||||||
|
return dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatAnswers formats DNS resource records for logging.
|
||||||
|
func FormatAnswers(answers []dns.RR) string {
|
||||||
|
if len(answers) == 0 {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := make([]string, 0, len(answers))
|
||||||
|
for _, rr := range answers {
|
||||||
|
switch r := rr.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
parts = append(parts, r.A.String())
|
||||||
|
case *dns.AAAA:
|
||||||
|
parts = append(parts, r.AAAA.String())
|
||||||
|
case *dns.CNAME:
|
||||||
|
parts = append(parts, "CNAME:"+r.Target)
|
||||||
|
case *dns.PTR:
|
||||||
|
parts = append(parts, "PTR:"+r.Ptr)
|
||||||
|
default:
|
||||||
|
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "[" + strings.Join(parts, ", ") + "]"
|
||||||
|
}
|
||||||
@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
localMuxUpdates, localRecords, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("local handler updater: %w", err)
|
return fmt.Errorf("local handler updater: %w", err)
|
||||||
}
|
}
|
||||||
@@ -499,7 +499,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.updateMux(muxUpdates)
|
s.updateMux(muxUpdates)
|
||||||
|
|
||||||
// register local records
|
// register local records
|
||||||
s.localResolver.Update(localRecords)
|
s.localResolver.Update(localRecords, localZones)
|
||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
@@ -659,9 +659,10 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) {
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, []domain.Domain, error) {
|
||||||
var muxUpdates []handlerWrapper
|
var muxUpdates []handlerWrapper
|
||||||
var localRecords []nbdns.SimpleRecord
|
var localRecords []nbdns.SimpleRecord
|
||||||
|
var zones []domain.Domain
|
||||||
|
|
||||||
for _, customZone := range customZones {
|
for _, customZone := range customZones {
|
||||||
if len(customZone.Records) == 0 {
|
if len(customZone.Records) == 0 {
|
||||||
@@ -675,6 +676,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
priority: PriorityLocal,
|
priority: PriorityLocal,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
zones = append(zones, domain.Domain(customZone.Domain))
|
||||||
|
|
||||||
for _, record := range customZone.Records {
|
for _, record := range customZone.Records {
|
||||||
if record.Class != nbdns.DefaultClass {
|
if record.Class != nbdns.DefaultClass {
|
||||||
log.Warnf("received an invalid class type: %s", record.Class)
|
log.Warnf("received an invalid class type: %s", record.Class)
|
||||||
@@ -685,7 +688,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return muxUpdates, localRecords, nil
|
return muxUpdates, localRecords, zones, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {
|
||||||
|
|||||||
@@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||||
dnsServer.localResolver.Update(testCase.initLocalRecords)
|
dnsServer.localResolver.Update(testCase.initLocalRecords, nil)
|
||||||
dnsServer.updateSerial = testCase.initSerial
|
dnsServer.updateSerial = testCase.initSerial
|
||||||
|
|
||||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||||
@@ -511,7 +511,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
|
||||||
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}})
|
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, nil)
|
||||||
dnsServer.updateSerial = 0
|
dnsServer.updateSerial = 0
|
||||||
|
|
||||||
nameServers := []nbdns.NameServer{
|
nameServers := []nbdns.NameServer{
|
||||||
@@ -2013,7 +2013,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
|
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
|
||||||
@@ -2074,7 +2074,7 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, localMuxUpdates, 1)
|
assert.Len(t, localMuxUpdates, 1)
|
||||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -21,6 +20,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -113,10 +113,7 @@ func (u *upstreamResolverBase) Stop() {
|
|||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
requestID := GenerateRequestID()
|
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
||||||
logger := log.WithField("request_id", requestID)
|
|
||||||
|
|
||||||
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
|
||||||
|
|
||||||
u.prepareRequest(r)
|
u.prepareRequest(r)
|
||||||
|
|
||||||
@@ -202,11 +199,14 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
|||||||
|
|
||||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
|
|
||||||
|
resutil.SetMeta(w, "upstream", upstream.String())
|
||||||
|
|
||||||
if err := w.WriteMsg(rm); err != nil {
|
if err := w.WriteMsg(rm); err != nil {
|
||||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,16 +414,6 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
|||||||
return rm, t, nil
|
return rm, t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateRequestID() string {
|
|
||||||
bytes := make([]byte, 4)
|
|
||||||
_, err := rand.Read(bytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to generate request ID: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
|
||||||
func FormatPeerStatus(peerState *peer.State) string {
|
func FormatPeerStatus(peerState *peer.State) string {
|
||||||
isConnected := peerState.ConnStatus == peer.StatusConnected
|
isConnected := peerState.ConnStatus == peer.StatusConnected
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
||||||
question.Name, question.Qtype, question.Qclass)
|
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
domain := strings.ToLower(question.Name)
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
var network string
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
switch question.Qtype {
|
if network == "" {
|
||||||
case dns.TypeA:
|
|
||||||
network = "ip4"
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
network = "ip6"
|
|
||||||
default:
|
|
||||||
// TODO: Handle other types
|
|
||||||
|
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
|
||||||
if err != nil {
|
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
||||||
f.handleDNSError(ctx, w, question, resp, domain, err)
|
if result.Err != nil {
|
||||||
|
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
for i, ip := range ips {
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
||||||
ips[i] = ip.Unmap()
|
f.cache.set(domain, question.Qtype, result.IPs)
|
||||||
}
|
|
||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
|
||||||
f.cache.set(domain, question.Qtype, ips)
|
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
resp := f.handleDNSQuery(w, query)
|
startTime := time.Now()
|
||||||
|
logger := log.WithFields(log.Fields{
|
||||||
|
"request_id": resutil.GenerateRequestID(),
|
||||||
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := f.handleDNSQuery(logger, w, query)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
resp := f.handleDNSQuery(w, query)
|
startTime := time.Now()
|
||||||
|
logger := log.WithFields(log.Fields{
|
||||||
|
"request_id": resutil.GenerateRequestID(),
|
||||||
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
|
})
|
||||||
|
|
||||||
|
resp := f.handleDNSQuery(logger, w, query)
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
|
||||||
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
|
||||||
//
|
|
||||||
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
|
||||||
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
|
||||||
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
|
||||||
// only handles A/AAAA queries and returns NOTIMP for other types.
|
|
||||||
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
|
||||||
// Try querying for a different record type to see if the domain exists
|
|
||||||
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
|
||||||
// This helps distinguish between NXDOMAIN and NODATA.
|
|
||||||
var alternativeNetwork string
|
|
||||||
switch originalQtype {
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
alternativeNetwork = "ip4"
|
|
||||||
case dns.TypeA:
|
|
||||||
alternativeNetwork = "ip6"
|
|
||||||
default:
|
|
||||||
resp.Rcode = dns.RcodeNameError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
|
||||||
var dnsErr *net.DNSError
|
|
||||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
|
||||||
// Alternative query also returned not found - domain truly doesn't exist
|
|
||||||
resp.Rcode = dns.RcodeNameError
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
|
||||||
resp.Rcode = dns.RcodeSuccess
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Alternative query succeeded - domain exists but has no records of this type
|
|
||||||
resp.Rcode = dns.RcodeSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||||
func (f *DNSForwarder) handleDNSError(
|
func (f *DNSForwarder) handleDNSError(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
logger *log.Entry,
|
||||||
w dns.ResponseWriter,
|
w dns.ResponseWriter,
|
||||||
question dns.Question,
|
question dns.Question,
|
||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
err error,
|
result resutil.LookupResult,
|
||||||
) {
|
) {
|
||||||
// Default to SERVFAIL; override below when appropriate.
|
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
|
||||||
|
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
|
|
||||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
resp.Rcode = result.Rcode
|
||||||
var dnsErr *net.DNSError
|
|
||||||
if !errors.As(err, &dnsErr) {
|
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
// NotFound: cache negative result and respond
|
||||||
if dnsErr.IsNotFound {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upstream failed but we might have a cached answer—serve it if present.
|
// Upstream failed but we might have a cached answer—serve it if present.
|
||||||
if ips, ok := f.cache.get(domain, qType); ok {
|
if ips, ok := f.cache.get(domain, qType); ok {
|
||||||
if len(ips) > 0 {
|
if len(ips) > 0 {
|
||||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||||
}
|
|
||||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
|
||||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cached negative result - re-verify NXDOMAIN vs NODATA
|
||||||
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
|
resp.Rcode = verifyResult.Rcode
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// No cache. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
if dnsErr.Server != "" {
|
var dnsErr *net.DNSError
|
||||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
|
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
// Write final failure response.
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
|
||||||
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
|
||||||
for _, ip := range ips {
|
|
||||||
var respRecord dns.RR
|
|
||||||
if ip.Is6() {
|
|
||||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
|
||||||
rr := dns.AAAA{
|
|
||||||
AAAA: ip.AsSlice(),
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: domain,
|
|
||||||
Rrtype: dns.TypeAAAA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: f.ttl,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
respRecord = &rr
|
|
||||||
} else {
|
|
||||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
|
||||||
rr := dns.A{
|
|
||||||
A: ip.AsSlice(),
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: domain,
|
|
||||||
Rrtype: dns.TypeA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: f.ttl,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
respRecord = &rr
|
|
||||||
}
|
|
||||||
resp.Answer = append(resp.Answer, respRecord)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(mockWriter, query)
|
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
var writtenResp *dns.Msg
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
_ = forwarder.handleDNSQuery(w2, q2)
|
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
require.NotNil(t, writtenResp, "expected response to be written")
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
var writtenResp *dns.Msg
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
_ = forwarder.handleDNSQuery(w2, q2)
|
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
require.NotNil(t, writtenResp)
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||||
if resp != nil && writtenResp == nil {
|
if resp != nil && writtenResp == nil {
|
||||||
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, resp, "Should return nil for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
assert.False(t, writeCalled, "Should not write response for empty query")
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||||
@@ -219,14 +220,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
|
|||||||
|
|
||||||
// ServeDNS implements the dns.Handler interface
|
// ServeDNS implements the dns.Handler interface
|
||||||
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
requestID := nbdns.GenerateRequestID()
|
logger := log.WithFields(log.Fields{
|
||||||
logger := log.WithField("request_id", requestID)
|
"request_id": resutil.GetRequestID(w),
|
||||||
|
"dns_id": fmt.Sprintf("%04x", r.Id),
|
||||||
|
})
|
||||||
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
|
|
||||||
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
|
||||||
|
|
||||||
// pass if non A/AAAA query
|
// pass if non A/AAAA query
|
||||||
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
|
||||||
@@ -280,15 +281,10 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var answer []dns.RR
|
resutil.SetMeta(w, "peer", peerKey)
|
||||||
if reply != nil {
|
|
||||||
answer = reply.Answer
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
|
|
||||||
|
|
||||||
reply.Id = r.Id
|
reply.Id = r.Id
|
||||||
if err := d.writeMsg(w, reply); err != nil {
|
if err := d.writeMsg(w, reply, logger); err != nil {
|
||||||
logger.Errorf("failed writing DNS response: %v", err)
|
logger.Errorf("failed writing DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -324,7 +320,7 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
|
|||||||
return peerAllowedIP, nil
|
return peerAllowedIP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
|
||||||
if r == nil {
|
if r == nil {
|
||||||
return fmt.Errorf("received nil DNS message")
|
return fmt.Errorf("received nil DNS message")
|
||||||
}
|
}
|
||||||
@@ -350,14 +346,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
case *dns.A:
|
case *dns.A:
|
||||||
addr, ok := netip.AddrFromSlice(rr.A)
|
addr, ok := netip.AddrFromSlice(rr.A)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip = addr
|
ip = addr
|
||||||
case *dns.AAAA:
|
case *dns.AAAA:
|
||||||
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
addr, ok := netip.AddrFromSlice(rr.AAAA)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip = addr
|
ip = addr
|
||||||
@@ -370,11 +366,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(newPrefixes) > 0 {
|
if len(newPrefixes) > 0 {
|
||||||
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
|
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
|
||||||
log.Errorf("failed to update domain prefixes: %v", err)
|
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
d.replaceIPsInDNSResponse(r, newPrefixes)
|
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,22 +382,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// logPrefixChanges handles the logging for prefix changes
|
// logPrefixChanges handles the logging for prefix changes
|
||||||
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
|
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
|
||||||
if len(toAdd) > 0 {
|
if len(toAdd) > 0 {
|
||||||
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
resolvedDomain.SafeString(),
|
resolvedDomain.SafeString(),
|
||||||
originalDomain.SafeString(),
|
originalDomain.SafeString(),
|
||||||
toAdd)
|
toAdd)
|
||||||
}
|
}
|
||||||
if len(toRemove) > 0 && !d.route.KeepRoute {
|
if len(toRemove) > 0 && !d.route.KeepRoute {
|
||||||
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
|
||||||
resolvedDomain.SafeString(),
|
resolvedDomain.SafeString(),
|
||||||
originalDomain.SafeString(),
|
originalDomain.SafeString(),
|
||||||
toRemove)
|
toRemove)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
|
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
defer d.mu.Unlock()
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
@@ -418,9 +414,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
realIP := prefix.Addr()
|
realIP := prefix.Addr()
|
||||||
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
|
||||||
dnatMappings[fakeIP] = realIP
|
dnatMappings[fakeIP] = realIP
|
||||||
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
|
||||||
} else {
|
} else {
|
||||||
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
|
logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -432,7 +428,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d.addDNATMappings(dnatMappings)
|
d.addDNATMappings(dnatMappings, logger)
|
||||||
|
|
||||||
if !d.route.KeepRoute {
|
if !d.route.KeepRoute {
|
||||||
// Remove old prefixes
|
// Remove old prefixes
|
||||||
@@ -448,7 +444,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
d.removeDNATMappings(toRemove)
|
d.removeDNATMappings(toRemove, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update domain prefixes using resolved domain as key - store real IPs
|
// Update domain prefixes using resolved domain as key - store real IPs
|
||||||
@@ -463,14 +459,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
|
|||||||
// Store real IPs for status (user-facing), not fake IPs
|
// Store real IPs for status (user-facing), not fake IPs
|
||||||
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
|
||||||
|
|
||||||
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
|
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
|
||||||
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
|
||||||
if len(realPrefixes) == 0 {
|
if len(realPrefixes) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -484,9 +480,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
|
|||||||
realIP := prefix.Addr()
|
realIP := prefix.Addr()
|
||||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
|
||||||
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
|
logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -502,7 +498,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addDNATMappings adds DNAT mappings to the firewall
|
// addDNATMappings adds DNAT mappings to the firewall
|
||||||
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
|
||||||
if len(mappings) == 0 {
|
if len(mappings) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -514,9 +510,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
|
|||||||
|
|
||||||
for fakeIP, realIP := range mappings {
|
for fakeIP, realIP := range mappings {
|
||||||
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
|
||||||
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
|
logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -528,12 +524,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, prefixes := range d.interceptedDomains {
|
for _, prefixes := range d.interceptedDomains {
|
||||||
d.removeDNATMappings(prefixes)
|
d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
|
||||||
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
|
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
|
||||||
if _, ok := d.internalDnatFw(); !ok {
|
if _, ok := d.internalDnatFw(); !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -549,7 +545,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
|
|||||||
|
|
||||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
rr.A = fakeIP.AsSlice()
|
rr.A = fakeIP.AsSlice()
|
||||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
case *dns.AAAA:
|
case *dns.AAAA:
|
||||||
@@ -560,7 +556,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
|
|||||||
|
|
||||||
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
|
||||||
rr.AAAA = fakeIP.AsSlice()
|
rr.AAAA = fakeIP.AsSlice()
|
||||||
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user