[client] Chase CNAMEs in local resolver to ensure musl compatibility (#5046)

This commit is contained in:
Viktor Liu
2026-01-12 19:35:38 +08:00
committed by GitHub
parent 614e7d5b90
commit 394ad19507
10 changed files with 1267 additions and 272 deletions

View File

@@ -3,11 +3,15 @@ package dns
import ( import (
"fmt" "fmt"
"slices" "slices"
"strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
) )
const ( const (
@@ -43,7 +47,23 @@ type HandlerChain struct {
type ResponseWriterChain struct { type ResponseWriterChain struct {
dns.ResponseWriter dns.ResponseWriter
origPattern string origPattern string
requestID string
shouldContinue bool shouldContinue bool
response *dns.Msg
meta map[string]string
}
// RequestID returns the request ID for tracing
func (w *ResponseWriterChain) RequestID() string {
return w.requestID
}
// SetMeta sets a metadata key-value pair for logging
func (w *ResponseWriterChain) SetMeta(key, value string) {
if w.meta == nil {
w.meta = make(map[string]string)
}
w.meta[key] = value
} }
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
@@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
w.shouldContinue = true w.shouldContinue = true
return nil return nil
} }
w.response = m
return w.ResponseWriter.WriteMsg(m) return w.ResponseWriter.WriteMsg(m)
} }
@@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
pos := c.findHandlerPosition(entry) pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
c.logHandlers()
} }
// findHandlerPosition determines where to insert a new handler based on priority and specificity // findHandlerPosition determines where to insert a new handler based on priority and specificity
@@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) {
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority)
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
c.logHandlers()
break break
} }
} }
} }
// logHandlers logs the current handler chain state. Caller must hold the lock.
func (c *HandlerChain) logHandlers() {
if !log.IsLevelEnabled(log.TraceLevel) {
return
}
var b strings.Builder
b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n")
for _, h := range c.handlers {
b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern +
" wildcard=" + strconv.FormatBool(h.IsWildcard) +
" match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) +
" priority=" + strconv.Itoa(h.Priority) + "\n")
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
qname := strings.ToLower(r.Question[0].Name) startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
question := r.Question[0]
qname := strings.ToLower(question.Name)
c.mu.RLock() c.mu.RLock()
handlers := slices.Clone(c.handlers) handlers := slices.Clone(c.handlers)
c.mu.RUnlock() c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers {
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
}
log.Trace(strings.TrimSuffix(b.String(), "\n"))
}
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
matched := c.isHandlerMatch(qname, entry) if !c.isHandlerMatch(qname, entry) {
continue
if matched {
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
// Only log continue for non-management cache handlers to reduce noise
if entry.Priority != PriorityMgmtCache {
log.Tracef("handler requested continue to next handler for domain=%s", qname)
}
continue
}
return
} }
handlerName := entry.OrigPattern
if s, ok := entry.Handler.(interface{ String() string }); ok {
handlerName = s.String()
}
logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass],
handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
requestID: requestID,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
if entry.Priority != PriorityMgmtCache {
logger.Tracef("handler requested continue for domain=%s", qname)
}
continue
}
c.logResponse(logger, chainWriter, qname, startTime)
return
} }
// No handler matched or all handlers passed // No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname) logger.Tracef("no handler found for domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := &dns.Msg{} resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeRefused) resp.SetRcode(r, dns.RcodeRefused)
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) logger.Errorf("failed to write DNS response: %v", err)
} }
} }
func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) {
if cw.response == nil {
return
}
var meta string
for k, v := range cw.meta {
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch { switch {
case entry.Pattern == ".": case entry.Pattern == ".":

View File

@@ -1,30 +1,50 @@
package local package local
import ( import (
"context"
"errors"
"fmt" "fmt"
"net"
"net/netip"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const externalResolutionTimeout = 4 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type Resolver struct { type Resolver struct {
mu sync.RWMutex mu sync.RWMutex
records map[dns.Question][]dns.RR records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{} domains map[domain.Domain]struct{}
zones []domain.Domain
resolver resolver
ctx context.Context
cancel context.CancelFunc
} }
func NewResolver() *Resolver { func NewResolver() *Resolver {
ctx, cancel := context.WithCancel(context.Background())
return &Resolver{ return &Resolver{
records: make(map[dns.Question][]dns.RR), records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}), domains: make(map[domain.Domain]struct{}),
ctx: ctx,
cancel: cancel,
} }
} }
@@ -37,7 +57,18 @@ func (d *Resolver) String() string {
return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
} }
func (d *Resolver) Stop() {} func (d *Resolver) Stop() {
if d.cancel != nil {
d.cancel()
}
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
d.zones = nil
}
// ID returns the unique handler ID // ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID { func (d *Resolver) ID() types.HandlerID {
@@ -48,38 +79,47 @@ func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithField("request_id", resutil.GetRequestID(w))
if len(r.Question) == 0 { if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question") logger.Debug("received local resolver request with no question")
return return
} }
question := r.Question[0] question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name)) question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{} replyMessage := &dns.Msg{}
replyMessage.SetReply(r) replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true replyMessage.RecursionAvailable = true
// lookup all records matching the question result := d.lookupRecords(logger, question)
records := d.lookupRecords(question) replyMessage.Authoritative = !result.hasExternalData
if len(records) > 0 { replyMessage.Answer = result.records
replyMessage.Rcode = dns.RcodeSuccess replyMessage.Rcode = d.determineRcode(question, result)
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// Check if we have any records for this domain name with different types
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
}
if err := w.WriteMsg(replyMessage); err != nil { if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err) logger.Warnf("failed to write the local resolver response: %v", err)
} }
} }
// determineRcode returns the appropriate DNS response code.
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
// even if CNAME records are included in the answer.
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
// Use the rcode from lookup - this properly handles CNAME chains where
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
if result.rcode != 0 {
return result.rcode
}
// No records found, but domain exists with different record types (NODATA)
if d.hasRecordsForDomain(domain.Domain(question.Name)) {
return dns.RcodeSuccess
}
return dns.RcodeNameError
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type // hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool { func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
d.mu.RLock() d.mu.RLock()
@@ -89,8 +129,33 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
return exists return exists
} }
// isInManagedZone checks if the given name falls within any of our managed zones.
// This is used to avoid unnecessary external resolution for CNAME targets that
// are within zones we manage - if we don't have a record for it, it doesn't exist.
// Caller must NOT hold the lock.
func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
name = dns.Fqdn(name)
for _, zone := range d.zones {
zoneStr := dns.Fqdn(zone.PunycodeString())
if strings.EqualFold(name, zoneStr) || strings.HasSuffix(strings.ToLower(name), strings.ToLower("."+zoneStr)) {
return true
}
}
return false
}
// lookupResult contains the result of a DNS lookup operation.
type lookupResult struct {
records []dns.RR
rcode int
hasExternalData bool
}
// lookupRecords fetches *all* DNS records matching the first question in r. // lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
d.mu.RLock() d.mu.RLock()
records, found := d.records[question] records, found := d.records[question]
@@ -98,10 +163,14 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RUnlock() d.mu.RUnlock()
// alternatively check if we have a cname // alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME { if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME cnameQuestion := dns.Question{
return d.lookupRecords(question) Name: question.Name,
Qtype: dns.TypeCNAME,
Qclass: question.Qclass,
}
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
} }
return nil return lookupResult{rcode: dns.RcodeNameError}
} }
recordsCopy := slices.Clone(records) recordsCopy := slices.Clone(records)
@@ -119,16 +188,172 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.Unlock() d.mu.Unlock()
} }
return recordsCopy return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
} }
func (d *Resolver) Update(update []nbdns.SimpleRecord) { // lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
// the final resolved record of the requested type. This is required for musl libc
// compatibility, which expects the full answer chain rather than just the CNAME.
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
const maxDepth = 8
var chain []dns.RR
for range maxDepth {
cnameRecords := d.getRecords(cnameQuestion)
if len(cnameRecords) == 0 {
break
}
chain = append(chain, cnameRecords...)
cname, ok := cnameRecords[0].(*dns.CNAME)
if !ok {
break
}
targetName := strings.ToLower(cname.Target)
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
// keep following chain
if targetResult.rcode == -1 {
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
continue
}
return d.buildChainResult(chain, targetResult)
}
if len(chain) > 0 {
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
}
return lookupResult{rcode: dns.RcodeSuccess}
}
// buildChainResult combines CNAME chain records with the target resolution result.
// Per RFC 6604, the final rcode is propagated through the chain.
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
records := chain
if len(target.records) > 0 {
records = append(records, target.records...)
}
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
return lookupResult{
records: records,
rcode: dns.RcodeServerFailure,
hasExternalData: true,
}
}
return lookupResult{
records: records,
rcode: target.rcode,
hasExternalData: target.hasExternalData,
}
}
// resolveCNAMETarget attempts to resolve a CNAME target name.
// Returns rcode=-1 to signal "keep following the chain".
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
return lookupResult{records: records, rcode: dns.RcodeSuccess}
}
// another CNAME, keep following
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
return lookupResult{rcode: -1}
}
// domain exists locally but not this record type (NODATA)
if d.hasRecordsForDomain(domain.Domain(targetName)) {
return lookupResult{rcode: dns.RcodeSuccess}
}
// in our zone but doesn't exist (NXDOMAIN)
if d.isInManagedZone(targetName) {
return lookupResult{rcode: dns.RcodeNameError}
}
return d.resolveExternal(logger, targetName, targetType)
}
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
d.mu.RLock()
defer d.mu.RUnlock()
return d.records[q]
}
func (d *Resolver) hasRecord(q dns.Question) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, ok := d.records[q]
return ok
}
// resolveExternal resolves a domain name using the system resolver.
// This is used to resolve CNAME targets that point outside our local zone,
// which is required for musl libc compatibility (musl expects complete answers).
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return lookupResult{rcode: dns.RcodeNotImplemented}
}
resolver := d.resolver
if resolver == nil {
resolver = net.DefaultResolver
}
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
defer cancel()
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
if result.Err != nil {
d.logDNSError(logger, name, qtype, result.Err)
return lookupResult{rcode: result.Rcode, hasExternalData: true}
}
return lookupResult{
records: resutil.IPsToRRs(name, result.IPs, 60),
rcode: dns.RcodeSuccess,
hasExternalData: true,
}
}
// logDNSError logs DNS resolution errors for debugging.
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
qtypeName := dns.TypeToString[qtype]
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
return
}
if dnsErr.IsNotFound {
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
return
}
if dnsErr.Server != "" {
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
} else {
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
}
}
// Update updates the resolver with new records and zone information.
// The zones parameter specifies which DNS zones this resolver manages.
func (d *Resolver) Update(update []nbdns.SimpleRecord, zones []domain.Domain) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
maps.Clear(d.records) maps.Clear(d.records)
maps.Clear(d.domains) maps.Clear(d.domains)
d.zones = zones
for _, rec := range update { for _, rec := range update {
if err := d.registerRecord(rec); err != nil { if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err) log.Warnf("failed to register the record (%s): %v", rec, err)

View File

@@ -1,8 +1,14 @@
package local package local
import ( import (
"context"
"fmt"
"net"
"net/netip"
"strings" "strings"
"sync"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -10,8 +16,21 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
) )
// mockResolver implements resolver for testing
type mockResolver struct {
lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error)
}
func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
if m.lookupFunc != nil {
return m.lookupFunc(ctx, network, host)
}
return nil, nil
}
func TestLocalResolver_ServeDNS(t *testing.T) { func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{ recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.", Name: "peera.netbird.cloud.",
@@ -110,7 +129,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
update2 := []nbdns.SimpleRecord{record2} update2 := []nbdns.SimpleRecord{record2}
// Apply first update // Apply first update
resolver.Update(update1) resolver.Update(update1, nil)
// Verify first update // Verify first update
resolver.mu.RLock() resolver.mu.RLock()
@@ -122,7 +141,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData) assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update // Apply second update
resolver.Update(update2) resolver.Update(update2, nil)
// Verify second update // Verify second update
resolver.mu.RLock() resolver.mu.RLock()
@@ -154,7 +173,7 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
update := []nbdns.SimpleRecord{record1, record2} update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records // Apply update with both records
resolver.Update(update) resolver.Update(update, nil)
// Create question that matches both records // Create question that matches both records
question := dns.Question{ question := dns.Question{
@@ -198,7 +217,7 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
update := []nbdns.SimpleRecord{record1, record2, record3} update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records // Apply update with all three records
resolver.Update(update) resolver.Update(update, nil)
msg := new(dns.Msg).SetQuestion(recordName, recordType) msg := new(dns.Msg).SetQuestion(recordName, recordType)
@@ -264,7 +283,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
} }
// Update resolver with the records // Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}) resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}, nil)
testCases := []struct { testCases := []struct {
name string name string
@@ -379,7 +398,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
} }
// Update resolver with both records // Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}) resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}, nil)
testCases := []struct { testCases := []struct {
name string name string
@@ -476,6 +495,20 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
// with 0 records instead of NXDOMAIN // with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) { func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver() resolver := NewResolver()
// Mock external resolver for CNAME target resolution
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "target.example.com." {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
if network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
recordA := nbdns.SimpleRecord{ recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.", Name: "example.netbird.cloud.",
@@ -493,7 +526,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
RData: "target.example.com.", RData: "target.example.com.",
} }
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}) resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}, nil)
testCases := []struct { testCases := []struct {
name string name string
@@ -582,3 +615,555 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
}) })
} }
} }
// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("simple internal CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "target.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "192.168.1.1", a.A.String())
})
t.Run("multi-hop CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, nil)
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3)
})
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok)
})
}
// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains
func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
t.Run("chain at max depth resolves", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 7 CNAMEs (under max of 8)
for i := 1; i <= 7; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("hop%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("hop%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update(records, nil)
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 8)
})
t.Run("chain exceeding max depth stops", func(t *testing.T) {
resolver := NewResolver()
var records []nbdns.SimpleRecord
// Create chain of 10 CNAMEs (exceeds max of 8)
for i := 1; i <= 10; i++ {
records = append(records, nbdns.SimpleRecord{
Name: fmt.Sprintf("deep%d.test.", i),
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: fmt.Sprintf("deep%d.test.", i+1),
})
}
records = append(records, nbdns.SimpleRecord{
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update(records, nil)
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
// Should NOT have the final A record (chain too deep)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
}, nil)
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.LessOrEqual(t, len(resp.Answer), 8)
})
}
// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains
func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + A record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
a, ok := resp.Answer[1].(*dns.A)
require.True(t, ok)
assert.Equal(t, "93.184.216.34", a.A.String())
})
t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip6" {
return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record")
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok)
assert.Equal(t, "external.example.com.", cname.Target)
aaaa, ok := resp.Answer[1].(*dns.AAAA)
require.True(t, ok)
assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String())
})
t.Run("concurrent external resolution", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
var wg sync.WaitGroup
results := make([]*dns.Msg, 10)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
results[idx] = resp
}(i)
}
wg.Wait()
for i, resp := range results {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i)
}
})
}
// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution
func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update sets zones correctly", func(t *testing.T) {
resolver := NewResolver()
zones := []domain.Domain{"example.com", "test.local"}
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, zones)
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("other.example.com."))
assert.True(t, resolver.isInManagedZone("sub.test.local."))
assert.False(t, resolver.isInManagedZone("external.com."))
})
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
resolver := NewResolver()
resolver.Update(nil, []domain.Domain{"Example.COM"})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
})
t.Run("Update clears zones", func(t *testing.T) {
resolver := NewResolver()
resolver.Update(nil, []domain.Domain{"example.com"})
assert.True(t, resolver.isInManagedZone("host.example.com."))
resolver.Update(nil, nil)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
}
// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
}, []domain.Domain{"myzone.test"})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should include CNAME in answer")
})
t.Run("CNAME to external domain skips zone check", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.other.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
}, []domain.Domain{"myzone.test"})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution")
})
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
resolver := NewResolver()
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
}, []domain.Domain{"myzone.test"})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." {
if network == "ip6" {
// No AAAA records
return nil, &net.DNSError{IsNotFound: true, Name: host}
}
if network == "ip4" {
// But A records exist - domain exists
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN")
require.Len(t, resp.Answer, 1, "Should have only CNAME")
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "Answer should be CNAME record")
})
// Table-driven test for all external resolution outcomes
externalCases := []struct {
name string
lookupFunc func(context.Context, string, string) ([]netip.Addr, error)
expectedRcode int
expectedAnswer int
}{
{
name: "external NXDOMAIN (both A and AAAA not found)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeNameError,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (temporary error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTemporary: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (timeout)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, &net.DNSError{IsTimeout: true, Name: host}
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external SERVFAIL (generic error)",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
return nil, fmt.Errorf("connection refused")
},
expectedRcode: dns.RcodeServerFailure,
expectedAnswer: 1, // CNAME only
},
{
name: "external success with IPs",
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, &net.DNSError{IsNotFound: true, Name: host}
},
expectedRcode: dns.RcodeSuccess,
expectedAnswer: 2, // CNAME + A
},
}
for _, tc := range externalCases {
t.Run(tc.name, func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch")
assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch")
if tc.expectedAnswer > 0 {
_, ok := resp.Answer[0].(*dns.CNAME)
assert.True(t, ok, "first answer should be CNAME")
}
})
}
}
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
t.Run("direct record lookup is authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, []domain.Domain{"example.com"})
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.True(t, resp.Authoritative)
})
t.Run("external resolution is not authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{
lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) {
if host == "external.example.com." && network == "ip4" {
return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil
}
return nil, nil
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 2)
assert.False(t, resp.Authoritative)
})
}
// TestLocalResolver_Stop tests cleanup on Stop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, []domain.Domain{"example.com"})
resolver.Stop()
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg)
require.NotNil(t, resp)
assert.Len(t, resp.Answer, 0)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, []domain.Domain{"example.com"})
resolver.Stop()
resolver.Stop()
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})
lookupCtxCanceled := make(chan struct{})
resolver.resolver = &mockResolver{
lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) {
close(lookupStarted)
<-ctx.Done()
close(lookupCtxCanceled)
return nil, ctx.Err()
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
done := make(chan struct{})
go func() {
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg)
close(done)
}()
<-lookupStarted
resolver.Stop()
select {
case <-lookupCtxCanceled:
case <-time.After(time.Second):
t.Fatal("external lookup context was not canceled")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("ServeDNS did not return after Stop")
}
})
}

View File

@@ -0,0 +1,197 @@
// Package resutil provides shared DNS resolution utilities
package resutil
import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"
"net/netip"
"strings"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
// GenerateRequestID creates a random 8-character hex string for request tracing.
func GenerateRequestID() string {
bytes := make([]byte, 4)
if _, err := rand.Read(bytes); err != nil {
log.Errorf("generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// IPsToRRs converts a slice of IP addresses to DNS resource records.
// IPv4 addresses become A records, IPv6 addresses become AAAA records.
func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR {
var result []dns.RR
for _, ip := range ips {
if ip.Is6() {
result = append(result, &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: ttl,
},
AAAA: ip.AsSlice(),
})
} else {
result = append(result, &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: ttl,
},
A: ip.AsSlice(),
})
}
}
return result
}
// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type.
// Returns empty string for unsupported types.
func NetworkForQtype(qtype uint16) string {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
return ""
}
}
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
// chainedWriter is implemented by ResponseWriters that carry request metadata
type chainedWriter interface {
RequestID() string
SetMeta(key, value string)
}
// GetRequestID extracts a request ID from the ResponseWriter if available,
// otherwise generates a new one.
func GetRequestID(w dns.ResponseWriter) string {
if cw, ok := w.(chainedWriter); ok {
if id := cw.RequestID(); id != "" {
return id
}
}
return GenerateRequestID()
}
// SetMeta sets metadata on the ResponseWriter if it supports it.
func SetMeta(w dns.ResponseWriter, key, value string) {
if cw, ok := w.(chainedWriter); ok {
cw.SetMeta(key, value)
}
}
// LookupResult contains the result of an external DNS lookup
type LookupResult struct {
IPs []netip.Addr
Rcode int
Err error // Original error for caller's logging needs
}
// LookupIP performs a DNS lookup and determines the appropriate rcode.
func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult {
ips, err := r.LookupNetIP(ctx, network, host)
if err != nil {
return LookupResult{
Rcode: getRcodeForError(ctx, r, host, qtype, err),
Err: err,
}
}
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return
for i, ip := range ips {
ips[i] = ip.Unmap()
}
return LookupResult{
IPs: ips,
Rcode: dns.RcodeSuccess,
}
}
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
return dns.RcodeServerFailure
}
if dnsErr.IsNotFound {
return getRcodeForNotFound(ctx, r, host, qtype)
}
return dns.RcodeServerFailure
}
// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA
// (domain exists but no records of requested type) by checking the opposite record type.
//
// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo,
// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.)
// are not queried by musl and don't need this handling.
func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
return dns.RcodeNameError
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
return dns.RcodeSuccess
}
// Alternative query succeeded - domain exists but has no records of this type
return dns.RcodeSuccess
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {
return "[]"
}
parts := make([]string, 0, len(answers))
for _, rr := range answers {
switch r := rr.(type) {
case *dns.A:
parts = append(parts, r.A.String())
case *dns.AAAA:
parts = append(parts, r.AAAA.String())
case *dns.CNAME:
parts = append(parts, "CNAME:"+r.Target)
case *dns.PTR:
parts = append(parts, "PTR:"+r.Ptr)
default:
parts = append(parts, dns.TypeToString[rr.Header().Rrtype])
}
}
return "[" + strings.Join(parts, ", ") + "]"
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
} }
} }
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) localMuxUpdates, localRecords, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil { if err != nil {
return fmt.Errorf("local handler updater: %w", err) return fmt.Errorf("local handler updater: %w", err)
} }
@@ -499,7 +499,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates) s.updateMux(muxUpdates)
// register local records // register local records
s.localResolver.Update(localRecords) s.localResolver.Update(localRecords, localZones)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -659,9 +659,10 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
} }
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, []domain.Domain, error) {
var muxUpdates []handlerWrapper var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord var localRecords []nbdns.SimpleRecord
var zones []domain.Domain
for _, customZone := range customZones { for _, customZone := range customZones {
if len(customZone.Records) == 0 { if len(customZone.Records) == 0 {
@@ -675,6 +676,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal, priority: PriorityLocal,
}) })
zones = append(zones, domain.Domain(customZone.Domain))
for _, record := range customZone.Records { for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass { if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class) log.Warnf("received an invalid class type: %s", record.Class)
@@ -685,7 +688,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
} }
} }
return muxUpdates, localRecords, nil return muxUpdates, localRecords, zones, nil
} }
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) { func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {

View File

@@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
}() }()
dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords) dnsServer.localResolver.Update(testCase.initLocalRecords, nil)
dnsServer.updateSerial = testCase.initSerial dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -511,7 +511,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
}, },
} }
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}) dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, nil)
dnsServer.updateSerial = 0 dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
@@ -2013,7 +2013,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
}, },
} }
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err) assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups) upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
@@ -2074,7 +2074,7 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
}, },
} }
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1) assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")

View File

@@ -2,7 +2,6 @@ package dns
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@@ -21,6 +20,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
@@ -113,10 +113,7 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID() logger := log.WithField("request_id", resutil.GetRequestID(w))
logger := log.WithField("request_id", requestID)
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r) u.prepareRequest(r)
@@ -202,11 +199,14 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1) u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
resutil.SetMeta(w, "upstream", upstream.String())
if err := w.WriteMsg(rm); err != nil { if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
} }
return true return true
} }
@@ -414,16 +414,6 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts // FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string { func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -18,6 +18,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg { func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
if len(query.Question) == 0 { if len(query.Question) == 0 {
return nil return nil
} }
question := query.Question[0] question := query.Question[0]
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, question.Qtype, question.Qclass) question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
domain := strings.ToLower(question.Name) domain := strings.ToLower(question.Name)
resp := query.SetReply(query) resp := query.SetReply(query)
var network string network := resutil.NetworkForQtype(question.Qtype)
switch question.Qtype { if network == "" {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
resp.Rcode = dns.RcodeNotImplemented resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) logger.Errorf("failed to write DNS response: %v", err)
} }
return nil return nil
} }
@@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
if mostSpecificResId == "" { if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) logger.Errorf("failed to write DNS response: %v", err)
} }
return nil return nil
} }
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
f.handleDNSError(ctx, w, question, resp, domain, err) if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil return nil
} }
// Unmap IPv4-mapped IPv6 addresses that some resolvers may return f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
for i, ip := range ips { resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
ips[i] = ip.Unmap() f.cache.set(domain, question.Qtype, result.IPs)
}
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp return resp
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil { if resp == nil {
return return
} }
@@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
} }
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) logger.Errorf("failed to write DNS response: %v", err)
return
} }
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
} }
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) startTime := time.Now()
logger := log.WithFields(log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil { if resp == nil {
return return
} }
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err) logger.Errorf("failed to write DNS response: %v", err)
return
} }
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
} }
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
} }
} }
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response. // handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError( func (f *DNSForwarder) handleDNSError(
ctx context.Context, ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter, w dns.ResponseWriter,
question dns.Question, question dns.Question,
resp *dns.Msg, resp *dns.Msg,
domain string, domain string,
err error, result resutil.LookupResult,
) { ) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype qType := question.Qtype
qTypeName := dns.TypeToString[qType] qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise. resp.Rcode = result.Rcode
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// NotFound: set NXDOMAIN / appropriate code via helper. // NotFound: cache negative result and respond
if dnsErr.IsNotFound { if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.cache.set(domain, question.Qtype, nil) f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return return
} }
// Upstream failed but we might have a cached answer—serve it if present. // Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok { if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 { if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil { if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr) logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
return
}
// Cached negative result - re-verify NXDOMAIN vs NODATA
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
} }
return
} }
// No cache. Log with or without the server field for more context. // No cache or verification failed. Log with or without the server field for more context.
if dnsErr.Server != "" { var dnsErr *net.DNSError
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else { } else {
log.Warnf(errResolveFailed, domain, err) logger.Warnf(errResolveFailed, domain, result.Err)
} }
// Write final failure response. // Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil { if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr) logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
} }
} }

View File

@@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query) resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
if tt.shouldResolve { if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain") require.NotNil(t, resp, "Expected response for authorized domain")
@@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery) resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
// Verify response // Verify response
if tt.shouldResolve { if tt.shouldResolve {
@@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA) query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query) resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Verify response contains all IPs // Verify response contains all IPs
require.NotNil(t, resp) require.NotNil(t, resp)
@@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
}, },
} }
_ = forwarder.handleDNSQuery(mockWriter, query) _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// Check the response written to the writer // Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written") require.NotNil(t, writtenResp, "Expected response to be written")
@@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{} q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{} w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1) resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1) require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1) require.Len(t, resp1.Answer, 1)
@@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2) _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp, "expected response to be written") require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{} q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA) q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{} w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1) resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
require.NotNil(t, resp1) require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1) require.Len(t, resp1.Answer, 1)
@@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2.SetQuestion("EXAMPLE.COM", dns.TypeA) q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2) _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
require.NotNil(t, writtenResp) require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
@@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA) query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{} mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query) resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
require.NotNil(t, resp) require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode) assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
}, },
} }
resp := forwarder.handleDNSQuery(mockWriter, query) resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions) // If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil { if resp != nil && writtenResp == nil {
@@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
return nil return nil
}, },
} }
resp := forwarder.handleDNSQuery(mockWriter, query) resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query") assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query") assert.False(t, writeCalled, "Should not write response for empty query")

View File

@@ -19,6 +19,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -219,14 +220,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface // ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID() logger := log.WithFields(log.Fields{
logger := log.WithField("request_id", requestID) "request_id": resutil.GetRequestID(w),
"dns_id": fmt.Sprintf("%04x", r.Id),
})
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query // pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
@@ -280,15 +281,10 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
var answer []dns.RR resutil.SetMeta(w, "peer", peerKey)
if reply != nil {
answer = reply.Answer
}
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil { if err := d.writeMsg(w, reply, logger); err != nil {
logger.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
} }
@@ -324,7 +320,7 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
return peerAllowedIP, nil return peerAllowedIP, nil
} }
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error {
if r == nil { if r == nil {
return fmt.Errorf("received nil DNS message") return fmt.Errorf("received nil DNS message")
} }
@@ -350,14 +346,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
case *dns.A: case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A) addr, ok := netip.AddrFromSlice(rr.A)
if !ok { if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue continue
} }
ip = addr ip = addr
case *dns.AAAA: case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA) addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok { if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue continue
} }
ip = addr ip = addr
@@ -370,11 +366,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
} }
if len(newPrefixes) > 0 { if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil {
log.Errorf("failed to update domain prefixes: %v", err) logger.Errorf("failed to update domain prefixes: %v", err)
} }
d.replaceIPsInDNSResponse(r, newPrefixes) d.replaceIPsInDNSResponse(r, newPrefixes, logger)
} }
} }
@@ -386,22 +382,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
} }
// logPrefixChanges handles the logging for prefix changes // logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) { func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) {
if len(toAdd) > 0 { if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(), resolvedDomain.SafeString(),
originalDomain.SafeString(), originalDomain.SafeString(),
toAdd) toAdd)
} }
if len(toRemove) > 0 && !d.route.KeepRoute { if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(), resolvedDomain.SafeString(),
originalDomain.SafeString(), originalDomain.SafeString(),
toRemove) toRemove)
} }
} }
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
@@ -418,9 +414,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
realIP := prefix.Addr() realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil { if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP) logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else { } else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err) logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err)
} }
} }
} }
@@ -432,7 +428,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
} }
} }
d.addDNATMappings(dnatMappings) d.addDNATMappings(dnatMappings, logger)
if !d.route.KeepRoute { if !d.route.KeepRoute {
// Remove old prefixes // Remove old prefixes
@@ -448,7 +444,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
} }
} }
d.removeDNATMappings(toRemove) d.removeDNATMappings(toRemove, logger)
} }
// Update domain prefixes using resolved domain as key - store real IPs // Update domain prefixes using resolved domain as key - store real IPs
@@ -463,14 +459,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
// Store real IPs for status (user-facing), not fake IPs // Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove) d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger)
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes // removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) { func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) {
if len(realPrefixes) == 0 { if len(realPrefixes) == 0 {
return return
} }
@@ -484,9 +480,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
realIP := prefix.Addr() realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil { if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err) logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else { } else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP) logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP)
} }
} }
} }
@@ -502,7 +498,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
} }
// addDNATMappings adds DNAT mappings to the firewall // addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) { func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) {
if len(mappings) == 0 { if len(mappings) == 0 {
return return
} }
@@ -514,9 +510,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
for fakeIP, realIP := range mappings { for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil { if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err) logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else { } else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP) logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP)
} }
} }
} }
@@ -528,12 +524,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() {
} }
for _, prefixes := range d.interceptedDomains { for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes) d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger()))
} }
} }
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response // replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) { func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) {
if _, ok := d.internalDnatFw(); !ok { if _, ok := d.internalDnatFw(); !ok {
return return
} }
@@ -549,7 +545,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice() rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
} }
case *dns.AAAA: case *dns.AAAA:
@@ -560,7 +556,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice() rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
} }
} }
} }