diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 2e54bffd9..7e3eb6d1f 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -3,11 +3,15 @@ package dns import ( "fmt" "slices" + "strconv" "strings" "sync" + "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/dns/resutil" ) const ( @@ -43,7 +47,23 @@ type HandlerChain struct { type ResponseWriterChain struct { dns.ResponseWriter origPattern string + requestID string shouldContinue bool + response *dns.Msg + meta map[string]string +} + +// RequestID returns the request ID for tracing +func (w *ResponseWriterChain) RequestID() string { + return w.requestID +} + +// SetMeta sets a metadata key-value pair for logging +func (w *ResponseWriterChain) SetMeta(key, value string) { + if w.meta == nil { + w.meta = make(map[string]string) + } + w.meta[key] = value } func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { @@ -52,6 +72,7 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { w.shouldContinue = true return nil } + w.response = m return w.ResponseWriter.WriteMsg(m) } @@ -101,6 +122,8 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority pos := c.findHandlerPosition(entry) c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) + + c.logHandlers() } // findHandlerPosition determines where to insert a new handler based on priority and specificity @@ -140,68 +163,109 @@ func (c *HandlerChain) removeEntry(pattern string, priority int) { for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { + log.Debugf("removing handler pattern: domain=%s priority=%d", entry.OrigPattern, priority) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + c.logHandlers() break } } } +// logHandlers logs the current handler chain state. Caller must hold the lock. +func (c *HandlerChain) logHandlers() { + if !log.IsLevelEnabled(log.TraceLevel) { + return + } + + var b strings.Builder + b.WriteString("handler chain (" + strconv.Itoa(len(c.handlers)) + "):\n") + for _, h := range c.handlers { + b.WriteString(" - pattern: domain=" + h.Pattern + " original: domain=" + h.OrigPattern + + " wildcard=" + strconv.FormatBool(h.IsWildcard) + + " match_subdomain=" + strconv.FormatBool(h.MatchSubdomains) + + " priority=" + strconv.Itoa(h.Priority) + "\n") + } + log.Trace(strings.TrimSuffix(b.String(), "\n")) +} + func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { return } - qname := strings.ToLower(r.Question[0].Name) + startTime := time.Now() + requestID := resutil.GenerateRequestID() + logger := log.WithFields(log.Fields{ + "request_id": requestID, + "dns_id": fmt.Sprintf("%04x", r.Id), + }) + + question := r.Question[0] + qname := strings.ToLower(question.Name) c.mu.RLock() handlers := slices.Clone(c.handlers) c.mu.RUnlock() - if log.IsLevelEnabled(log.TraceLevel) { - var b strings.Builder - b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers))) - for _, h := range handlers { - b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n", - h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)) - } - log.Trace(strings.TrimSuffix(b.String(), "\n")) - } - // Try handlers in priority order for _, entry := range handlers { - matched := c.isHandlerMatch(qname, entry) - - if matched { - log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d", - qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) - - chainWriter := &ResponseWriterChain{ - ResponseWriter: w, - origPattern: entry.OrigPattern, - } - entry.Handler.ServeDNS(chainWriter, r) - - // If handler wants to continue, try next handler - if chainWriter.shouldContinue { - // Only log continue for non-management cache handlers to reduce noise - if entry.Priority != PriorityMgmtCache { - log.Tracef("handler requested continue to next handler for domain=%s", qname) - } - continue - } - return + if !c.isHandlerMatch(qname, entry) { + continue } + + handlerName := entry.OrigPattern + if s, ok := entry.Handler.(interface{ String() string }); ok { + handlerName = s.String() + } + + logger.Tracef("question: domain=%s type=%s class=%s -> handler=%s pattern=%s wildcard=%v match_subdomain=%v priority=%d", + qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass], + handlerName, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) + + chainWriter := &ResponseWriterChain{ + ResponseWriter: w, + origPattern: entry.OrigPattern, + requestID: requestID, + } + entry.Handler.ServeDNS(chainWriter, r) + + // If handler wants to continue, try next handler + if chainWriter.shouldContinue { + if entry.Priority != PriorityMgmtCache { + logger.Tracef("handler requested continue for domain=%s", qname) + } + continue + } + + c.logResponse(logger, chainWriter, qname, startTime) + return } // No handler matched or all handlers passed - log.Tracef("no handler found for domain=%s", qname) + logger.Tracef("no handler found for domain=%s type=%s class=%s", + qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) resp := &dns.Msg{} resp.SetRcode(r, dns.RcodeRefused) if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS response: %v", err) + logger.Errorf("failed to write DNS response: %v", err) } } +func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, qname string, startTime time.Time) { + if cw.response == nil { + return + } + + var meta string + for k, v := range cw.meta { + meta += " " + k + "=" + v + } + + logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s", + qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer), + meta, time.Since(startTime)) +} + func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { switch { case entry.Pattern == ".": diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index bac7875ec..cb1fa5293 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -1,30 +1,50 @@ package local import ( + "context" + "errors" "fmt" + "net" + "net/netip" "slices" "strings" "sync" + "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/shared/management/domain" ) +const externalResolutionTimeout = 4 * time.Second + +type resolver interface { + LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) +} + type Resolver struct { - mu sync.RWMutex - records map[dns.Question][]dns.RR - domains map[domain.Domain]struct{} + mu sync.RWMutex + records map[dns.Question][]dns.RR + domains map[domain.Domain]struct{} + zones []domain.Domain + resolver resolver + + ctx context.Context + cancel context.CancelFunc } func NewResolver() *Resolver { + ctx, cancel := context.WithCancel(context.Background()) return &Resolver{ records: make(map[dns.Question][]dns.RR), domains: make(map[domain.Domain]struct{}), + ctx: ctx, + cancel: cancel, } } @@ -37,7 +57,18 @@ func (d *Resolver) String() string { return fmt.Sprintf("LocalResolver [%d records]", len(d.records)) } -func (d *Resolver) Stop() {} +func (d *Resolver) Stop() { + if d.cancel != nil { + d.cancel() + } + + d.mu.Lock() + defer d.mu.Unlock() + + maps.Clear(d.records) + maps.Clear(d.domains) + d.zones = nil +} // ID returns the unique handler ID func (d *Resolver) ID() types.HandlerID { @@ -48,38 +79,47 @@ func (d *Resolver) ProbeAvailability() {} // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + logger := log.WithField("request_id", resutil.GetRequestID(w)) + if len(r.Question) == 0 { - log.Debugf("received local resolver request with no question") + logger.Debug("received local resolver request with no question") return } question := r.Question[0] question.Name = strings.ToLower(dns.Fqdn(question.Name)) - log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass) - replyMessage := &dns.Msg{} replyMessage.SetReply(r) replyMessage.RecursionAvailable = true - // lookup all records matching the question - records := d.lookupRecords(question) - if len(records) > 0 { - replyMessage.Rcode = dns.RcodeSuccess - replyMessage.Answer = append(replyMessage.Answer, records...) - } else { - // Check if we have any records for this domain name with different types - if d.hasRecordsForDomain(domain.Domain(question.Name)) { - replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records - } else { - replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN - } - } + result := d.lookupRecords(logger, question) + replyMessage.Authoritative = !result.hasExternalData + replyMessage.Answer = result.records + replyMessage.Rcode = d.determineRcode(question, result) if err := w.WriteMsg(replyMessage); err != nil { - log.Warnf("failed to write the local resolver response: %v", err) + logger.Warnf("failed to write the local resolver response: %v", err) } } +// determineRcode returns the appropriate DNS response code. +// Per RFC 6604, CNAME chains should return the rcode of the final target resolution, +// even if CNAME records are included in the answer. +func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int { + // Use the rcode from lookup - this properly handles CNAME chains where + // the target may be NXDOMAIN or SERVFAIL even though we have CNAME records + if result.rcode != 0 { + return result.rcode + } + + // No records found, but domain exists with different record types (NODATA) + if d.hasRecordsForDomain(domain.Domain(question.Name)) { + return dns.RcodeSuccess + } + + return dns.RcodeNameError +} + // hasRecordsForDomain checks if any records exist for the given domain name regardless of type func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool { d.mu.RLock() @@ -89,8 +129,33 @@ func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool { return exists } +// isInManagedZone checks if the given name falls within any of our managed zones. +// This is used to avoid unnecessary external resolution for CNAME targets that +// are within zones we manage - if we don't have a record for it, it doesn't exist. +// Caller must NOT hold the lock. +func (d *Resolver) isInManagedZone(name string) bool { + d.mu.RLock() + defer d.mu.RUnlock() + + name = dns.Fqdn(name) + for _, zone := range d.zones { + zoneStr := dns.Fqdn(zone.PunycodeString()) + if strings.EqualFold(name, zoneStr) || strings.HasSuffix(strings.ToLower(name), strings.ToLower("."+zoneStr)) { + return true + } + } + return false +} + +// lookupResult contains the result of a DNS lookup operation. +type lookupResult struct { + records []dns.RR + rcode int + hasExternalData bool +} + // lookupRecords fetches *all* DNS records matching the first question in r. -func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { +func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult { d.mu.RLock() records, found := d.records[question] @@ -98,10 +163,14 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { d.mu.RUnlock() // alternatively check if we have a cname if question.Qtype != dns.TypeCNAME { - question.Qtype = dns.TypeCNAME - return d.lookupRecords(question) + cnameQuestion := dns.Question{ + Name: question.Name, + Qtype: dns.TypeCNAME, + Qclass: question.Qclass, + } + return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype) } - return nil + return lookupResult{rcode: dns.RcodeNameError} } recordsCopy := slices.Clone(records) @@ -119,16 +188,172 @@ func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { d.mu.Unlock() } - return recordsCopy + return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess} } -func (d *Resolver) Update(update []nbdns.SimpleRecord) { +// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with +// the final resolved record of the requested type. This is required for musl libc +// compatibility, which expects the full answer chain rather than just the CNAME. +func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult { + const maxDepth = 8 + var chain []dns.RR + + for range maxDepth { + cnameRecords := d.getRecords(cnameQuestion) + if len(cnameRecords) == 0 { + break + } + + chain = append(chain, cnameRecords...) + + cname, ok := cnameRecords[0].(*dns.CNAME) + if !ok { + break + } + + targetName := strings.ToLower(cname.Target) + targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass) + + // keep following chain + if targetResult.rcode == -1 { + cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass} + continue + } + + return d.buildChainResult(chain, targetResult) + } + + if len(chain) > 0 { + return lookupResult{records: chain, rcode: dns.RcodeSuccess} + } + return lookupResult{rcode: dns.RcodeSuccess} +} + +// buildChainResult combines CNAME chain records with the target resolution result. +// Per RFC 6604, the final rcode is propagated through the chain. +func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult { + records := chain + if len(target.records) > 0 { + records = append(records, target.records...) + } + + // preserve hasExternalData for SERVFAIL so caller knows the error came from upstream + if target.hasExternalData && target.rcode == dns.RcodeServerFailure { + return lookupResult{ + records: records, + rcode: dns.RcodeServerFailure, + hasExternalData: true, + } + } + + return lookupResult{ + records: records, + rcode: target.rcode, + hasExternalData: target.hasExternalData, + } +} + +// resolveCNAMETarget attempts to resolve a CNAME target name. +// Returns rcode=-1 to signal "keep following the chain". +func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult { + if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 { + return lookupResult{records: records, rcode: dns.RcodeSuccess} + } + + // another CNAME, keep following + if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) { + return lookupResult{rcode: -1} + } + + // domain exists locally but not this record type (NODATA) + if d.hasRecordsForDomain(domain.Domain(targetName)) { + return lookupResult{rcode: dns.RcodeSuccess} + } + + // in our zone but doesn't exist (NXDOMAIN) + if d.isInManagedZone(targetName) { + return lookupResult{rcode: dns.RcodeNameError} + } + + return d.resolveExternal(logger, targetName, targetType) +} + +func (d *Resolver) getRecords(q dns.Question) []dns.RR { + d.mu.RLock() + defer d.mu.RUnlock() + return d.records[q] +} + +func (d *Resolver) hasRecord(q dns.Question) bool { + d.mu.RLock() + defer d.mu.RUnlock() + _, ok := d.records[q] + return ok +} + +// resolveExternal resolves a domain name using the system resolver. +// This is used to resolve CNAME targets that point outside our local zone, +// which is required for musl libc compatibility (musl expects complete answers). +func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return lookupResult{rcode: dns.RcodeNotImplemented} + } + + resolver := d.resolver + if resolver == nil { + resolver = net.DefaultResolver + } + + ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout) + defer cancel() + + result := resutil.LookupIP(ctx, resolver, network, name, qtype) + if result.Err != nil { + d.logDNSError(logger, name, qtype, result.Err) + return lookupResult{rcode: result.Rcode, hasExternalData: true} + } + + return lookupResult{ + records: resutil.IPsToRRs(name, result.IPs, 60), + rcode: dns.RcodeSuccess, + hasExternalData: true, + } +} + +// logDNSError logs DNS resolution errors for debugging. +func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) { + qtypeName := dns.TypeToString[qtype] + + var dnsErr *net.DNSError + if !errors.As(err, &dnsErr) { + logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err) + return + } + + if dnsErr.IsNotFound { + logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName) + return + } + + if dnsErr.Server != "" { + logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err) + } else { + logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err) + } +} + +// Update updates the resolver with new records and zone information. +// The zones parameter specifies which DNS zones this resolver manages. +func (d *Resolver) Update(update []nbdns.SimpleRecord, zones []domain.Domain) { d.mu.Lock() defer d.mu.Unlock() maps.Clear(d.records) maps.Clear(d.domains) + d.zones = zones + for _, rec := range update { if err := d.registerRecord(rec); err != nil { log.Warnf("failed to register the record (%s): %v", rec, err) diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 8b13b69ff..2f8e08b1a 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -1,8 +1,14 @@ package local import ( + "context" + "fmt" + "net" + "net/netip" "strings" + "sync" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -10,8 +16,21 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/test" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/shared/management/domain" ) +// mockResolver implements resolver for testing +type mockResolver struct { + lookupFunc func(ctx context.Context, network, host string) ([]netip.Addr, error) +} + +func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + if m.lookupFunc != nil { + return m.lookupFunc(ctx, network, host) + } + return nil, nil +} + func TestLocalResolver_ServeDNS(t *testing.T) { recordA := nbdns.SimpleRecord{ Name: "peera.netbird.cloud.", @@ -110,7 +129,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) { update2 := []nbdns.SimpleRecord{record2} // Apply first update - resolver.Update(update1) + resolver.Update(update1, nil) // Verify first update resolver.mu.RLock() @@ -122,7 +141,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) { assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData) // Apply second update - resolver.Update(update2) + resolver.Update(update2, nil) // Verify second update resolver.mu.RLock() @@ -154,7 +173,7 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) { update := []nbdns.SimpleRecord{record1, record2} // Apply update with both records - resolver.Update(update) + resolver.Update(update, nil) // Create question that matches both records question := dns.Question{ @@ -198,7 +217,7 @@ func TestLocalResolver_RecordRotation(t *testing.T) { update := []nbdns.SimpleRecord{record1, record2, record3} // Apply update with all three records - resolver.Update(update) + resolver.Update(update, nil) msg := new(dns.Msg).SetQuestion(recordName, recordType) @@ -264,7 +283,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) { } // Update resolver with the records - resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}) + resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}, nil) testCases := []struct { name string @@ -379,7 +398,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) { } // Update resolver with both records - resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}) + resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}, nil) testCases := []struct { name string @@ -476,6 +495,20 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) { // with 0 records instead of NXDOMAIN func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) { resolver := NewResolver() + // Mock external resolver for CNAME target resolution + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "target.example.com." { + if network == "ip4" { + return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil + } + if network == "ip6" { + return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil + } + } + return nil, &net.DNSError{IsNotFound: true, Name: host} + }, + } recordA := nbdns.SimpleRecord{ Name: "example.netbird.cloud.", @@ -493,7 +526,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) { RData: "target.example.com.", } - resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}) + resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}, nil) testCases := []struct { name string @@ -582,3 +615,555 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) { }) } } + +// TestLocalResolver_CNAMEChainResolution tests comprehensive CNAME chain following +func TestLocalResolver_CNAMEChainResolution(t *testing.T) { + t.Run("simple internal CNAME chain", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 2) + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "target.example.com.", cname.Target) + + a, ok := resp.Answer[1].(*dns.A) + require.True(t, ok) + assert.Equal(t, "192.168.1.1", a.A.String()) + }) + + t.Run("multi-hop CNAME chain", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."}, + {Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."}, + {Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, nil) + + msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 3) + }) + + t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Len(t, resp.Answer, 1) + _, ok := resp.Answer[0].(*dns.CNAME) + assert.True(t, ok) + }) +} + +// TestLocalResolver_CNAMEMaxDepth tests the maximum depth limit for CNAME chains +func TestLocalResolver_CNAMEMaxDepth(t *testing.T) { + t.Run("chain at max depth resolves", func(t *testing.T) { + resolver := NewResolver() + var records []nbdns.SimpleRecord + // Create chain of 7 CNAMEs (under max of 8) + for i := 1; i <= 7; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("hop%d.test.", i), + Type: int(dns.TypeCNAME), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("hop%d.test.", i+1), + }) + } + records = append(records, nbdns.SimpleRecord{ + Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10", + }) + + resolver.Update(records, nil) + + msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 8) + }) + + t.Run("chain exceeding max depth stops", func(t *testing.T) { + resolver := NewResolver() + var records []nbdns.SimpleRecord + // Create chain of 10 CNAMEs (exceeds max of 8) + for i := 1; i <= 10; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("deep%d.test.", i), + Type: int(dns.TypeCNAME), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("deep%d.test.", i+1), + }) + } + records = append(records, nbdns.SimpleRecord{ + Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10", + }) + + resolver.Update(records, nil) + + msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + // Should NOT have the final A record (chain too deep) + assert.LessOrEqual(t, len(resp.Answer), 8) + }) + + t.Run("circular CNAME is protected by max depth", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."}, + {Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.LessOrEqual(t, len(resp.Answer), 8) + }) +} + +// TestLocalResolver_ExternalCNAMEResolution tests CNAME resolution to external domains +func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) { + t.Run("CNAME to external domain resolves via external resolver", func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "external.example.com." && network == "ip4" { + return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil + } + return nil, nil + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Len(t, resp.Answer, 2, "Should have CNAME + A record") + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "external.example.com.", cname.Target) + + a, ok := resp.Answer[1].(*dns.A) + require.True(t, ok) + assert.Equal(t, "93.184.216.34", a.A.String()) + }) + + t.Run("CNAME to external domain resolves IPv6", func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "external.example.com." && network == "ip6" { + return []netip.Addr{netip.MustParseAddr("2606:2800:220:1:248:1893:25c8:1946")}, nil + } + return nil, nil + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Len(t, resp.Answer, 2, "Should have CNAME + AAAA record") + + cname, ok := resp.Answer[0].(*dns.CNAME) + require.True(t, ok) + assert.Equal(t, "external.example.com.", cname.Target) + + aaaa, ok := resp.Answer[1].(*dns.AAAA) + require.True(t, ok) + assert.Equal(t, "2606:2800:220:1:248:1893:25c8:1946", aaaa.AAAA.String()) + }) + + t.Run("concurrent external resolution", func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "external.example.com." && network == "ip4" { + return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil + } + return nil, nil + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + var wg sync.WaitGroup + results := make([]*dns.Msg, 10) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + msg := new(dns.Msg).SetQuestion("concurrent.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + results[idx] = resp + }(i) + } + wg.Wait() + + for i, resp := range results { + require.NotNil(t, resp, "Response %d should not be nil", i) + require.Len(t, resp.Answer, 2, "Response %d should have CNAME + A", i) + } + }) +} + +// TestLocalResolver_ZoneManagement tests zone-aware CNAME resolution +func TestLocalResolver_ZoneManagement(t *testing.T) { + t.Run("Update sets zones correctly", func(t *testing.T) { + resolver := NewResolver() + + zones := []domain.Domain{"example.com", "test.local"} + resolver.Update([]nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, zones) + + assert.True(t, resolver.isInManagedZone("host.example.com.")) + assert.True(t, resolver.isInManagedZone("other.example.com.")) + assert.True(t, resolver.isInManagedZone("sub.test.local.")) + assert.False(t, resolver.isInManagedZone("external.com.")) + }) + + t.Run("isInManagedZone case insensitive", func(t *testing.T) { + resolver := NewResolver() + resolver.Update(nil, []domain.Domain{"Example.COM"}) + + assert.True(t, resolver.isInManagedZone("host.example.com.")) + assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM.")) + }) + + t.Run("Update clears zones", func(t *testing.T) { + resolver := NewResolver() + resolver.Update(nil, []domain.Domain{"example.com"}) + assert.True(t, resolver.isInManagedZone("host.example.com.")) + + resolver.Update(nil, nil) + assert.False(t, resolver.isInManagedZone("host.example.com.")) + }) +} + +// TestLocalResolver_CNAMEZoneAwareResolution tests CNAME resolution with zone awareness +func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { + t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."}, + }, []domain.Domain{"myzone.test"}) + + msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeNameError, resp.Rcode, "Should return NXDOMAIN") + require.Len(t, resp.Answer, 1, "Should include CNAME in answer") + }) + + t.Run("CNAME to external domain skips zone check", func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "external.other.com." && network == "ip4" { + return []netip.Addr{netip.MustParseAddr("203.0.113.1")}, nil + } + return nil, nil + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."}, + }, []domain.Domain{"myzone.test"}) + + msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 2, "Should have CNAME + A from external resolution") + }) + + t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) { + resolver := NewResolver() + // CNAME points to target that has A but no AAAA - query for AAAA should be NODATA + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."}, + {Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"}, + }, []domain.Domain{"myzone.test"}) + + msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN") + require.Len(t, resp.Answer, 1, "Should have only CNAME, no AAAA") + _, ok := resp.Answer[0].(*dns.CNAME) + assert.True(t, ok, "Answer should be CNAME record") + }) + + t.Run("external CNAME target exists but no AAAA records (NODATA)", func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "external.example.com." { + if network == "ip6" { + // No AAAA records + return nil, &net.DNSError{IsNotFound: true, Name: host} + } + if network == "ip4" { + // But A records exist - domain exists + return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil + } + } + return nil, &net.DNSError{IsNotFound: true, Name: host} + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "Should return NODATA (success), not NXDOMAIN") + require.Len(t, resp.Answer, 1, "Should have only CNAME") + _, ok := resp.Answer[0].(*dns.CNAME) + assert.True(t, ok, "Answer should be CNAME record") + }) + + // Table-driven test for all external resolution outcomes + externalCases := []struct { + name string + lookupFunc func(context.Context, string, string) ([]netip.Addr, error) + expectedRcode int + expectedAnswer int + }{ + { + name: "external NXDOMAIN (both A and AAAA not found)", + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + return nil, &net.DNSError{IsNotFound: true, Name: host} + }, + expectedRcode: dns.RcodeNameError, + expectedAnswer: 1, // CNAME only + }, + { + name: "external SERVFAIL (temporary error)", + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + return nil, &net.DNSError{IsTemporary: true, Name: host} + }, + expectedRcode: dns.RcodeServerFailure, + expectedAnswer: 1, // CNAME only + }, + { + name: "external SERVFAIL (timeout)", + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + return nil, &net.DNSError{IsTimeout: true, Name: host} + }, + expectedRcode: dns.RcodeServerFailure, + expectedAnswer: 1, // CNAME only + }, + { + name: "external SERVFAIL (generic error)", + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + return nil, fmt.Errorf("connection refused") + }, + expectedRcode: dns.RcodeServerFailure, + expectedAnswer: 1, // CNAME only + }, + { + name: "external success with IPs", + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if network == "ip4" { + return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil + } + return nil, &net.DNSError{IsNotFound: true, Name: host} + }, + expectedRcode: dns.RcodeSuccess, + expectedAnswer: 2, // CNAME + A + }, + } + + for _, tc := range externalCases { + t.Run(tc.name, func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc} + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Equal(t, tc.expectedRcode, resp.Rcode, "rcode mismatch") + assert.Len(t, resp.Answer, tc.expectedAnswer, "answer count mismatch") + if tc.expectedAnswer > 0 { + _, ok := resp.Answer[0].(*dns.CNAME) + assert.True(t, ok, "first answer should be CNAME") + } + }) + } +} + +// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior +func TestLocalResolver_AuthoritativeFlag(t *testing.T) { + t.Run("direct record lookup is authoritative", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, []domain.Domain{"example.com"}) + + msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.True(t, resp.Authoritative) + }) + + t.Run("external resolution is not authoritative", func(t *testing.T) { + resolver := NewResolver() + resolver.resolver = &mockResolver{ + lookupFunc: func(_ context.Context, network, host string) ([]netip.Addr, error) { + if host == "external.example.com." && network == "ip4" { + return []netip.Addr{netip.MustParseAddr("93.184.216.34")}, nil + } + return nil, nil + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + require.Len(t, resp.Answer, 2) + assert.False(t, resp.Authoritative) + }) +} + +// TestLocalResolver_Stop tests cleanup on Stop +func TestLocalResolver_Stop(t *testing.T) { + t.Run("Stop clears all state", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, []domain.Domain{"example.com"}) + + resolver.Stop() + + msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA) + var resp *dns.Msg + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { resp = m; return nil }}, msg) + + require.NotNil(t, resp) + assert.Len(t, resp.Answer, 0) + assert.False(t, resolver.isInManagedZone("host.example.com.")) + }) + + t.Run("Stop is safe to call multiple times", func(t *testing.T) { + resolver := NewResolver() + resolver.Update([]nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, []domain.Domain{"example.com"}) + + resolver.Stop() + resolver.Stop() + resolver.Stop() + }) + + t.Run("Stop cancels in-flight external resolution", func(t *testing.T) { + resolver := NewResolver() + + lookupStarted := make(chan struct{}) + lookupCtxCanceled := make(chan struct{}) + + resolver.resolver = &mockResolver{ + lookupFunc: func(ctx context.Context, network, host string) ([]netip.Addr, error) { + close(lookupStarted) + <-ctx.Done() + close(lookupCtxCanceled) + return nil, ctx.Err() + }, + } + + resolver.Update([]nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, nil) + + done := make(chan struct{}) + go func() { + msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) + resolver.ServeDNS(&test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}, msg) + close(done) + }() + + <-lookupStarted + resolver.Stop() + + select { + case <-lookupCtxCanceled: + case <-time.After(time.Second): + t.Fatal("external lookup context was not canceled") + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("ServeDNS did not return after Stop") + } + }) +} diff --git a/client/internal/dns/resutil/resolve.go b/client/internal/dns/resutil/resolve.go new file mode 100644 index 000000000..5a3744719 --- /dev/null +++ b/client/internal/dns/resutil/resolve.go @@ -0,0 +1,197 @@ +// Package resutil provides shared DNS resolution utilities +package resutil + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "net" + "net/netip" + "strings" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +// GenerateRequestID creates a random 8-character hex string for request tracing. +func GenerateRequestID() string { + bytes := make([]byte, 4) + if _, err := rand.Read(bytes); err != nil { + log.Errorf("generate request ID: %v", err) + return "" + } + return hex.EncodeToString(bytes) +} + +// IPsToRRs converts a slice of IP addresses to DNS resource records. +// IPv4 addresses become A records, IPv6 addresses become AAAA records. +func IPsToRRs(name string, ips []netip.Addr, ttl uint32) []dns.RR { + var result []dns.RR + + for _, ip := range ips { + if ip.Is6() { + result = append(result, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: ttl, + }, + AAAA: ip.AsSlice(), + }) + } else { + result = append(result, &dns.A{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: ttl, + }, + A: ip.AsSlice(), + }) + } + } + + return result +} + +// NetworkForQtype returns the network string ("ip4" or "ip6") for a DNS query type. +// Returns empty string for unsupported types. +func NetworkForQtype(qtype uint16) string { + switch qtype { + case dns.TypeA: + return "ip4" + case dns.TypeAAAA: + return "ip6" + default: + return "" + } +} + +type resolver interface { + LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) +} + +// chainedWriter is implemented by ResponseWriters that carry request metadata +type chainedWriter interface { + RequestID() string + SetMeta(key, value string) +} + +// GetRequestID extracts a request ID from the ResponseWriter if available, +// otherwise generates a new one. +func GetRequestID(w dns.ResponseWriter) string { + if cw, ok := w.(chainedWriter); ok { + if id := cw.RequestID(); id != "" { + return id + } + } + return GenerateRequestID() +} + +// SetMeta sets metadata on the ResponseWriter if it supports it. +func SetMeta(w dns.ResponseWriter, key, value string) { + if cw, ok := w.(chainedWriter); ok { + cw.SetMeta(key, value) + } +} + +// LookupResult contains the result of an external DNS lookup +type LookupResult struct { + IPs []netip.Addr + Rcode int + Err error // Original error for caller's logging needs +} + +// LookupIP performs a DNS lookup and determines the appropriate rcode. +func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint16) LookupResult { + ips, err := r.LookupNetIP(ctx, network, host) + if err != nil { + return LookupResult{ + Rcode: getRcodeForError(ctx, r, host, qtype, err), + Err: err, + } + } + + // Unmap IPv4-mapped IPv6 addresses that some resolvers may return + for i, ip := range ips { + ips[i] = ip.Unmap() + } + + return LookupResult{ + IPs: ips, + Rcode: dns.RcodeSuccess, + } +} + +func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int { + var dnsErr *net.DNSError + if !errors.As(err, &dnsErr) { + return dns.RcodeServerFailure + } + + if dnsErr.IsNotFound { + return getRcodeForNotFound(ctx, r, host, qtype) + } + + return dns.RcodeServerFailure +} + +// getRcodeForNotFound distinguishes between NXDOMAIN (domain doesn't exist) and NODATA +// (domain exists but no records of requested type) by checking the opposite record type. +// +// musl libc (the reason we need this distinction) only queries A/AAAA pairs in getaddrinfo, +// so checking the opposite A/AAAA type is sufficient. Other record types (MX, TXT, etc.) +// are not queried by musl and don't need this handling. +func getRcodeForNotFound(ctx context.Context, r resolver, domain string, originalQtype uint16) int { + // Try querying for a different record type to see if the domain exists + // If the original query was for AAAA, try A. If it was for A, try AAAA. + // This helps distinguish between NXDOMAIN and NODATA. + var alternativeNetwork string + switch originalQtype { + case dns.TypeAAAA: + alternativeNetwork = "ip4" + case dns.TypeA: + alternativeNetwork = "ip6" + default: + return dns.RcodeNameError + } + + if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil { + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) && dnsErr.IsNotFound { + // Alternative query also returned not found - domain truly doesn't exist + return dns.RcodeNameError + } + // Some other error (timeout, server failure, etc.) - can't determine, assume domain exists + return dns.RcodeSuccess + } + + // Alternative query succeeded - domain exists but has no records of this type + return dns.RcodeSuccess +} + +// FormatAnswers formats DNS resource records for logging. +func FormatAnswers(answers []dns.RR) string { + if len(answers) == 0 { + return "[]" + } + + parts := make([]string, 0, len(answers)) + for _, rr := range answers { + switch r := rr.(type) { + case *dns.A: + parts = append(parts, r.A.String()) + case *dns.AAAA: + parts = append(parts, r.AAAA.String()) + case *dns.CNAME: + parts = append(parts, "CNAME:"+r.Target) + case *dns.PTR: + parts = append(parts, "PTR:"+r.Ptr) + default: + parts = append(parts, dns.TypeToString[rr.Header().Rrtype]) + } + } + return "[" + strings.Join(parts, ", ") + "]" +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 94945b55a..0a56b92a1 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { } } - localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones) + localMuxUpdates, localRecords, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { return fmt.Errorf("local handler updater: %w", err) } @@ -499,7 +499,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.updateMux(muxUpdates) // register local records - s.localResolver.Update(localRecords) + s.localResolver.Update(localRecords, localZones) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) @@ -659,9 +659,10 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) } -func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, error) { +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, []domain.Domain, error) { var muxUpdates []handlerWrapper var localRecords []nbdns.SimpleRecord + var zones []domain.Domain for _, customZone := range customZones { if len(customZone.Records) == 0 { @@ -675,6 +676,8 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) priority: PriorityLocal, }) + zones = append(zones, domain.Domain(customZone.Domain)) + for _, record := range customZone.Records { if record.Class != nbdns.DefaultClass { log.Warnf("received an invalid class type: %s", record.Class) @@ -685,7 +688,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) } } - return muxUpdates, localRecords, nil + return muxUpdates, localRecords, zones, nil } func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index fe1f67f66..2b5b460b4 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) { }() dnsServer.dnsMuxMap = testCase.initUpstreamMap - dnsServer.localResolver.Update(testCase.initLocalRecords) + dnsServer.localResolver.Update(testCase.initLocalRecords, nil) dnsServer.updateSerial = testCase.initSerial err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -511,7 +511,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { }, } //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} - dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}) + dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, nil) dnsServer.updateSerial = 0 nameServers := []nbdns.NameServer{ @@ -2013,7 +2013,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) { }, } - localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones) assert.NoError(t, err) upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups) @@ -2074,7 +2074,7 @@ func TestLocalResolverPriorityConstants(t *testing.T) { }, } - localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones) assert.NoError(t, err) assert.Len(t, localMuxUpdates, 1) assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 2a92fd6d8..6b52010fb 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -2,7 +2,6 @@ package dns import ( "context" - "crypto/rand" "crypto/sha256" "encoding/hex" "errors" @@ -21,6 +20,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" @@ -113,10 +113,7 @@ func (u *upstreamResolverBase) Stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - requestID := GenerateRequestID() - logger := log.WithField("request_id", requestID) - - logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + logger := log.WithField("request_id", resutil.GetRequestID(w)) u.prepareRequest(r) @@ -202,11 +199,14 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { u.successCount.Add(1) - logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain) + + resutil.SetMeta(w, "upstream", upstream.String()) if err := w.WriteMsg(rm); err != nil { logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) + return true } + return true } @@ -414,16 +414,6 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } -func GenerateRequestID() string { - bytes := make([]byte, 4) - _, err := rand.Read(bytes) - if err != nil { - log.Errorf("failed to generate request ID: %v", err) - return "" - } - return hex.EncodeToString(bytes) -} - // FormatPeerStatus formats peer connection status information for debugging DNS timeouts func FormatPeerStatus(peerState *peer.State) string { isConnected := peerState.ConnStatus == peer.StatusConnected diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 6b8042ccb..1230a4e46 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -18,6 +18,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/route" ) @@ -189,29 +190,22 @@ func (f *DNSForwarder) Close(ctx context.Context) error { return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg { +func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg { if len(query.Question) == 0 { return nil } question := query.Question[0] - log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", - question.Name, question.Qtype, question.Qclass) + logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s", + question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) domain := strings.ToLower(question.Name) resp := query.SetReply(query) - var network string - switch question.Qtype { - case dns.TypeA: - network = "ip4" - case dns.TypeAAAA: - network = "ip6" - default: - // TODO: Handle other types - + network := resutil.NetworkForQtype(question.Qtype) + if network == "" { resp.Rcode = dns.RcodeNotImplemented if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS response: %v", err) + logger.Errorf("failed to write DNS response: %v", err) } return nil } @@ -221,33 +215,35 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns if mostSpecificResId == "" { resp.Rcode = dns.RcodeRefused if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS response: %v", err) + logger.Errorf("failed to write DNS response: %v", err) } return nil } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - ips, err := f.resolver.LookupNetIP(ctx, network, domain) - if err != nil { - f.handleDNSError(ctx, w, question, resp, domain, err) + + result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype) + if result.Err != nil { + f.handleDNSError(ctx, logger, w, question, resp, domain, result) return nil } - // Unmap IPv4-mapped IPv6 addresses that some resolvers may return - for i, ip := range ips { - ips[i] = ip.Unmap() - } - - f.updateInternalState(ips, mostSpecificResId, matchingEntries) - f.addIPsToResponse(resp, domain, ips) - f.cache.set(domain, question.Qtype, ips) + f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries) + resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...) + f.cache.set(domain, question.Qtype, result.IPs) return resp } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { - resp := f.handleDNSQuery(w, query) + startTime := time.Now() + logger := log.WithFields(log.Fields{ + "request_id": resutil.GenerateRequestID(), + "dns_id": fmt.Sprintf("%04x", query.Id), + }) + + resp := f.handleDNSQuery(logger, w, query) if resp == nil { return } @@ -265,19 +261,33 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { } if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS response: %v", err) + logger.Errorf("failed to write DNS response: %v", err) + return } + + logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", + query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { - resp := f.handleDNSQuery(w, query) + startTime := time.Now() + logger := log.WithFields(log.Fields{ + "request_id": resutil.GenerateRequestID(), + "dns_id": fmt.Sprintf("%04x", query.Id), + }) + + resp := f.handleDNSQuery(logger, w, query) if resp == nil { return } if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS response: %v", err) + logger.Errorf("failed to write DNS response: %v", err) + return } + + logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", + query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) } func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { @@ -315,140 +325,64 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe } } -// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true -// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type) -// -// LIMITATION: This function only checks A and AAAA record types to determine domain existence. -// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records, -// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder -// only handles A/AAAA queries and returns NOTIMP for other types. -func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) { - // Try querying for a different record type to see if the domain exists - // If the original query was for AAAA, try A. If it was for A, try AAAA. - // This helps distinguish between NXDOMAIN and NODATA. - var alternativeNetwork string - switch originalQtype { - case dns.TypeAAAA: - alternativeNetwork = "ip4" - case dns.TypeA: - alternativeNetwork = "ip6" - default: - resp.Rcode = dns.RcodeNameError - return - } - - if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil { - var dnsErr *net.DNSError - if errors.As(err, &dnsErr) && dnsErr.IsNotFound { - // Alternative query also returned not found - domain truly doesn't exist - resp.Rcode = dns.RcodeNameError - return - } - // Some other error (timeout, server failure, etc.) - can't determine, assume domain exists - resp.Rcode = dns.RcodeSuccess - return - } - - // Alternative query succeeded - domain exists but has no records of this type - resp.Rcode = dns.RcodeSuccess -} - // handleDNSError processes DNS lookup errors and sends an appropriate error response. func (f *DNSForwarder) handleDNSError( ctx context.Context, + logger *log.Entry, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, - err error, + result resutil.LookupResult, ) { - // Default to SERVFAIL; override below when appropriate. - resp.Rcode = dns.RcodeServerFailure - qType := question.Qtype qTypeName := dns.TypeToString[qType] - // Prefer typed DNS errors; fall back to generic logging otherwise. - var dnsErr *net.DNSError - if !errors.As(err, &dnsErr) { - log.Warnf(errResolveFailed, domain, err) - if writeErr := w.WriteMsg(resp); writeErr != nil { - log.Errorf("failed to write failure DNS response: %v", writeErr) - } - return - } + resp.Rcode = result.Rcode - // NotFound: set NXDOMAIN / appropriate code via helper. - if dnsErr.IsNotFound { - f.setResponseCodeForNotFound(ctx, resp, domain, qType) - if writeErr := w.WriteMsg(resp); writeErr != nil { - log.Errorf("failed to write failure DNS response: %v", writeErr) - } + // NotFound: cache negative result and respond + if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess { f.cache.set(domain, question.Qtype, nil) + if writeErr := w.WriteMsg(resp); writeErr != nil { + logger.Errorf("failed to write failure DNS response: %v", writeErr) + } return } // Upstream failed but we might have a cached answer—serve it if present. if ips, ok := f.cache.get(domain, qType); ok { if len(ips) > 0 { - log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) - f.addIPsToResponse(resp, domain, ips) + logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...) resp.Rcode = dns.RcodeSuccess if writeErr := w.WriteMsg(resp); writeErr != nil { - log.Errorf("failed to write cached DNS response: %v", writeErr) - } - } else { // send NXDOMAIN / appropriate code if cache is empty - f.setResponseCodeForNotFound(ctx, resp, domain, qType) - if writeErr := w.WriteMsg(resp); writeErr != nil { - log.Errorf("failed to write failure DNS response: %v", writeErr) + logger.Errorf("failed to write cached DNS response: %v", writeErr) } + return + } + + // Cached negative result - re-verify NXDOMAIN vs NODATA + verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType) + if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess { + resp.Rcode = verifyResult.Rcode + if writeErr := w.WriteMsg(resp); writeErr != nil { + logger.Errorf("failed to write failure DNS response: %v", writeErr) + } + return } - return } - // No cache. Log with or without the server field for more context. - if dnsErr.Server != "" { - log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) + // No cache or verification failed. Log with or without the server field for more context. + var dnsErr *net.DNSError + if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" { + logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) } else { - log.Warnf(errResolveFailed, domain, err) + logger.Warnf(errResolveFailed, domain, result.Err) } // Write final failure response. if writeErr := w.WriteMsg(resp); writeErr != nil { - log.Errorf("failed to write failure DNS response: %v", writeErr) - } -} - -// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records -func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) { - for _, ip := range ips { - var respRecord dns.RR - if ip.Is6() { - log.Tracef("resolved domain=%s to IPv6=%s", domain, ip) - rr := dns.AAAA{ - AAAA: ip.AsSlice(), - Hdr: dns.RR_Header{ - Name: domain, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: f.ttl, - }, - } - respRecord = &rr - } else { - log.Tracef("resolved domain=%s to IPv4=%s", domain, ip) - rr := dns.A{ - A: ip.AsSlice(), - Hdr: dns.RR_Header{ - Name: domain, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: f.ttl, - }, - } - respRecord = &rr - } - resp.Answer = append(resp.Answer, respRecord) + logger.Errorf("failed to write failure DNS response: %v", writeErr) } } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 4d0b96a75..6416c2f21 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/miekg/dns" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -317,7 +318,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(mockWriter, query) + resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") @@ -465,7 +466,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(mockWriter, dnsQuery) + resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery) // Verify response if tt.shouldResolve { @@ -527,7 +528,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { query.SetQuestion("example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(mockWriter, query) + resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) // Verify response contains all IPs require.NotNil(t, resp) @@ -604,7 +605,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { }, } - _ = forwarder.handleDNSQuery(mockWriter, query) + _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) // Check the response written to the writer require.NotNil(t, writtenResp, "Expected response to be written") @@ -674,7 +675,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(w1, q1) + resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -684,7 +685,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) var writtenResp *dns.Msg w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(w2, q2) + _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) require.NotNil(t, writtenResp, "expected response to be written") require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) @@ -714,7 +715,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(mixedQuery+".", dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(w1, q1) + resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -728,7 +729,7 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q2.SetQuestion("EXAMPLE.COM", dns.TypeA) var writtenResp *dns.Msg w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(w2, q2) + _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) require.NotNil(t, writtenResp) require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) @@ -783,7 +784,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { query.SetQuestion("smtp.mail.example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(mockWriter, query) + resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) require.NotNil(t, resp) assert.Equal(t, dns.RcodeSuccess, resp.Rcode) @@ -904,7 +905,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { }, } - resp := forwarder.handleDNSQuery(mockWriter, query) + resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) // If a response was returned, it means it should be written (happens in wrapper functions) if resp != nil && writtenResp == nil { @@ -937,7 +938,7 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) { return nil }, } - resp := forwarder.handleDNSQuery(mockWriter, query) + resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) assert.Nil(t, resp, "Should return nil for empty query") assert.False(t, writeCalled, "Should not write response for empty query") diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 348338dac..928b85acb 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -19,6 +19,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common" @@ -219,14 +220,14 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error { // ServeDNS implements the dns.Handler interface func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - requestID := nbdns.GenerateRequestID() - logger := log.WithField("request_id", requestID) + logger := log.WithFields(log.Fields{ + "request_id": resutil.GetRequestID(w), + "dns_id": fmt.Sprintf("%04x", r.Id), + }) if len(r.Question) == 0 { return } - logger.Tracef("received DNS request for domain=%s type=%v class=%v", - r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) // pass if non A/AAAA query if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { @@ -280,15 +281,10 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - var answer []dns.RR - if reply != nil { - answer = reply.Answer - } - - logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) + resutil.SetMeta(w, "peer", peerKey) reply.Id = r.Id - if err := d.writeMsg(w, reply); err != nil { + if err := d.writeMsg(w, reply, logger); err != nil { logger.Errorf("failed writing DNS response: %v", err) } } @@ -324,7 +320,7 @@ func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) { return peerAllowedIP, nil } -func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { +func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) error { if r == nil { return fmt.Errorf("received nil DNS message") } @@ -350,14 +346,14 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { case *dns.A: addr, ok := netip.AddrFromSlice(rr.A) if !ok { - log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) + logger.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) continue } ip = addr case *dns.AAAA: addr, ok := netip.AddrFromSlice(rr.AAAA) if !ok { - log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) + logger.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) continue } ip = addr @@ -370,11 +366,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { } if len(newPrefixes) > 0 { - if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { - log.Errorf("failed to update domain prefixes: %v", err) + if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes, logger); err != nil { + logger.Errorf("failed to update domain prefixes: %v", err) } - d.replaceIPsInDNSResponse(r, newPrefixes) + d.replaceIPsInDNSResponse(r, newPrefixes, logger) } } @@ -386,22 +382,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { } // logPrefixChanges handles the logging for prefix changes -func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) { +func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix, logger *log.Entry) { if len(toAdd) > 0 { - log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + logger.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", resolvedDomain.SafeString(), originalDomain.SafeString(), toAdd) } if len(toRemove) > 0 && !d.route.KeepRoute { - log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + logger.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", resolvedDomain.SafeString(), originalDomain.SafeString(), toRemove) } } -func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { +func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix, logger *log.Entry) error { d.mu.Lock() defer d.mu.Unlock() @@ -418,9 +414,9 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom realIP := prefix.Addr() if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil { dnatMappings[fakeIP] = realIP - log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP) + logger.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP) } else { - log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err) + logger.Errorf("failed to allocate fake IP for %s: %v", realIP, err) } } } @@ -432,7 +428,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom } } - d.addDNATMappings(dnatMappings) + d.addDNATMappings(dnatMappings, logger) if !d.route.KeepRoute { // Remove old prefixes @@ -448,7 +444,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom } } - d.removeDNATMappings(toRemove) + d.removeDNATMappings(toRemove, logger) } // Update domain prefixes using resolved domain as key - store real IPs @@ -463,14 +459,14 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom // Store real IPs for status (user-facing), not fake IPs d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) - d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove) + d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove, logger) } return nberrors.FormatErrorOrNil(merr) } // removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes -func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) { +func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix, logger *log.Entry) { if len(realPrefixes) == 0 { return } @@ -484,9 +480,9 @@ func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) { realIP := prefix.Addr() if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil { - log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err) + logger.Errorf("failed to remove DNAT mapping for %s: %v", fakeIP, err) } else { - log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP) + logger.Debugf("removed DNAT mapping: %s -> %s", fakeIP, realIP) } } } @@ -502,7 +498,7 @@ func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) { } // addDNATMappings adds DNAT mappings to the firewall -func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) { +func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr, logger *log.Entry) { if len(mappings) == 0 { return } @@ -514,9 +510,9 @@ func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) { for fakeIP, realIP := range mappings { if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil { - log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err) + logger.Errorf("failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err) } else { - log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP) + logger.Debugf("added DNAT mapping: %s -> %s", fakeIP, realIP) } } } @@ -528,12 +524,12 @@ func (d *DnsInterceptor) cleanupDNATMappings() { } for _, prefixes := range d.interceptedDomains { - d.removeDNATMappings(prefixes) + d.removeDNATMappings(prefixes, log.NewEntry(log.StandardLogger())) } } // replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response -func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) { +func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix, logger *log.Entry) { if _, ok := d.internalDnatFw(); !ok { return } @@ -549,7 +545,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes [] if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { rr.A = fakeIP.AsSlice() - log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) + logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) } case *dns.AAAA: @@ -560,7 +556,7 @@ func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes [] if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists { rr.AAAA = fakeIP.AsSlice() - log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) + logger.Tracef("replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP) } } }