mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Per RFC 4592 section 2.2.1, wildcards should only match when the queried name does not exist in the zone. Previously, if host.example.com had an A record and *.example.com had an AAAA record, querying AAAA for host.example.com would incorrectly return the wildcard AAAA instead of NODATA. Now the resolver checks if the domain exists (with any record type) before falling back to wildcard matching, returning proper NODATA responses for existing names without the requested record type.
488 lines
14 KiB
Go
488 lines
14 KiB
Go
package local
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/exp/maps"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
|
"github.com/netbirdio/netbird/client/internal/dns/types"
|
|
nbdns "github.com/netbirdio/netbird/dns"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
)
|
|
|
|
// externalResolutionTimeout bounds each external lookup issued by
// resolveExternal; it is the timeout applied to the per-lookup context.
const externalResolutionTimeout = 4 * time.Second
|
|
|
|
// resolver is the lookup dependency used for external resolution. The method
// has the same signature as (*net.Resolver).LookupNetIP, which is why
// net.DefaultResolver can be used when no resolver is configured.
type resolver interface {
	// LookupNetIP looks up host on the given network and returns its addresses.
	LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
|
|
|
|
type Resolver struct {
|
|
mu sync.RWMutex
|
|
records map[dns.Question][]dns.RR
|
|
domains map[domain.Domain]struct{}
|
|
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
|
zones map[domain.Domain]bool
|
|
resolver resolver
|
|
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
func NewResolver() *Resolver {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &Resolver{
|
|
records: make(map[dns.Question][]dns.RR),
|
|
domains: make(map[domain.Domain]struct{}),
|
|
zones: make(map[domain.Domain]bool),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
func (d *Resolver) MatchSubdomains() bool {
|
|
return true
|
|
}
|
|
|
|
// String returns a string representation of the local resolver
|
|
func (d *Resolver) String() string {
|
|
return fmt.Sprintf("LocalResolver [%d records]", len(d.records))
|
|
}
|
|
|
|
func (d *Resolver) Stop() {
|
|
if d.cancel != nil {
|
|
d.cancel()
|
|
}
|
|
|
|
d.mu.Lock()
|
|
defer d.mu.Unlock()
|
|
|
|
maps.Clear(d.records)
|
|
maps.Clear(d.domains)
|
|
maps.Clear(d.zones)
|
|
}
|
|
|
|
// ID returns the unique handler ID
|
|
func (d *Resolver) ID() types.HandlerID {
|
|
return "local-resolver"
|
|
}
|
|
|
|
// ProbeAvailability is a no-op for the local resolver; the body is empty.
func (d *Resolver) ProbeAvailability() {}
|
|
|
|
// ServeDNS handles a DNS request
|
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
logger := log.WithField("request_id", resutil.GetRequestID(w))
|
|
|
|
if len(r.Question) == 0 {
|
|
logger.Debug("received local resolver request with no question")
|
|
return
|
|
}
|
|
question := r.Question[0]
|
|
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
|
|
|
replyMessage := &dns.Msg{}
|
|
replyMessage.SetReply(r)
|
|
replyMessage.RecursionAvailable = true
|
|
|
|
result := d.lookupRecords(logger, question)
|
|
replyMessage.Authoritative = !result.hasExternalData
|
|
replyMessage.Answer = result.records
|
|
replyMessage.Rcode = d.determineRcode(question, result)
|
|
|
|
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
|
|
d.continueToNext(logger, w, r)
|
|
return
|
|
}
|
|
|
|
if err := w.WriteMsg(replyMessage); err != nil {
|
|
logger.Warnf("failed to write the local resolver response: %v", err)
|
|
}
|
|
}
|
|
|
|
// determineRcode returns the appropriate DNS response code.
|
|
// Per RFC 6604, CNAME chains should return the rcode of the final target resolution,
|
|
// even if CNAME records are included in the answer.
|
|
func (d *Resolver) determineRcode(question dns.Question, result lookupResult) int {
|
|
// Use the rcode from lookup - this properly handles CNAME chains where
|
|
// the target may be NXDOMAIN or SERVFAIL even though we have CNAME records
|
|
if result.rcode != 0 {
|
|
return result.rcode
|
|
}
|
|
|
|
// No records found, but domain exists with different record types (NODATA)
|
|
if d.hasRecordsForDomain(domain.Domain(question.Name), question.Qtype) {
|
|
return dns.RcodeSuccess
|
|
}
|
|
|
|
return dns.RcodeNameError
|
|
}
|
|
|
|
// findZone finds the matching zone for a query name using reverse suffix lookup.
|
|
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
|
|
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
|
|
qname = strings.ToLower(dns.Fqdn(qname))
|
|
for {
|
|
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
|
|
return nonAuth, true
|
|
}
|
|
// Move to parent domain
|
|
idx := strings.Index(qname, ".")
|
|
if idx == -1 || idx == len(qname)-1 {
|
|
return false, false
|
|
}
|
|
qname = qname[idx+1:]
|
|
}
|
|
}
|
|
|
|
// shouldFallthrough checks if the query should fallthrough to the next handler.
|
|
// Returns true if the queried name belongs to a non-authoritative zone.
|
|
func (d *Resolver) shouldFallthrough(qname string) bool {
|
|
d.mu.RLock()
|
|
defer d.mu.RUnlock()
|
|
|
|
nonAuth, found := d.findZone(qname)
|
|
return found && nonAuth
|
|
}
|
|
|
|
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
|
|
resp := &dns.Msg{}
|
|
resp.SetRcode(r, dns.RcodeNameError)
|
|
resp.MsgHdr.Zero = true
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
logger.Warnf("failed to write continue signal: %v", err)
|
|
}
|
|
}
|
|
|
|
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
|
|
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain, qType uint16) bool {
|
|
d.mu.RLock()
|
|
defer d.mu.RUnlock()
|
|
|
|
_, exists := d.domains[domainName]
|
|
if !exists && supportsWildcard(qType) {
|
|
testWild := transformDomainToWildcard(string(domainName))
|
|
_, exists = d.domains[domain.Domain(testWild)]
|
|
}
|
|
return exists
|
|
}
|
|
|
|
// isInManagedZone checks if the given name falls within any of our managed zones.
|
|
// This is used to avoid unnecessary external resolution for CNAME targets that
|
|
// are within zones we manage - if we don't have a record for it, it doesn't exist.
|
|
// Caller must NOT hold the lock.
|
|
func (d *Resolver) isInManagedZone(name string) bool {
|
|
d.mu.RLock()
|
|
defer d.mu.RUnlock()
|
|
|
|
_, found := d.findZone(name)
|
|
return found
|
|
}
|
|
|
|
// lookupResult contains the result of a DNS lookup operation.
|
|
type lookupResult struct {
|
|
records []dns.RR
|
|
rcode int
|
|
hasExternalData bool
|
|
}
|
|
|
|
// lookupRecords fetches *all* DNS records matching the first question in r.
|
|
func (d *Resolver) lookupRecords(logger *log.Entry, question dns.Question) lookupResult {
|
|
d.mu.RLock()
|
|
records, found := d.records[question]
|
|
usingWildcard := false
|
|
wildQuestion := transformToWildcard(question)
|
|
// RFC 4592 section 2.2.1: wildcard only matches if the name does NOT exist in the zone.
|
|
// If the domain exists with any record type, return NODATA instead of wildcard match.
|
|
if !found && supportsWildcard(question.Qtype) {
|
|
if _, domainExists := d.domains[domain.Domain(question.Name)]; !domainExists {
|
|
records, found = d.records[wildQuestion]
|
|
usingWildcard = found
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
d.mu.RUnlock()
|
|
// alternatively check if we have a cname
|
|
if question.Qtype != dns.TypeCNAME {
|
|
cnameQuestion := dns.Question{
|
|
Name: question.Name,
|
|
Qtype: dns.TypeCNAME,
|
|
Qclass: question.Qclass,
|
|
}
|
|
return d.lookupCNAMEChain(logger, cnameQuestion, question.Qtype)
|
|
}
|
|
return lookupResult{rcode: dns.RcodeNameError}
|
|
}
|
|
|
|
recordsCopy := slices.Clone(records)
|
|
d.mu.RUnlock()
|
|
|
|
// if there's more than one record, rotate them (round-robin)
|
|
if len(recordsCopy) > 1 {
|
|
d.mu.Lock()
|
|
q := question
|
|
if usingWildcard {
|
|
q = wildQuestion
|
|
}
|
|
records = d.records[q]
|
|
if len(records) > 1 {
|
|
first := records[0]
|
|
records = append(records[1:], first)
|
|
d.records[q] = records
|
|
}
|
|
d.mu.Unlock()
|
|
}
|
|
|
|
if usingWildcard {
|
|
return responseFromWildRecords(question.Name, wildQuestion.Name, recordsCopy)
|
|
}
|
|
|
|
return lookupResult{records: recordsCopy, rcode: dns.RcodeSuccess}
|
|
}
|
|
|
|
func transformToWildcard(question dns.Question) dns.Question {
|
|
wildQuestion := question
|
|
wildQuestion.Name = transformDomainToWildcard(wildQuestion.Name)
|
|
return wildQuestion
|
|
}
|
|
|
|
// transformDomainToWildcard replaces the first label of domain with "*",
// e.g. "host.example.com." becomes "*.example.com.". A name without any dot
// (including the empty string) becomes "*".
//
// The function runs for every query, so it locates the first dot directly
// instead of splitting and re-joining all labels.
func transformDomainToWildcard(domain string) string {
	firstDot := strings.IndexByte(domain, '.')
	if firstDot < 0 {
		return "*"
	}
	return "*" + domain[firstDot:]
}
|
|
|
|
func supportsWildcard(queryType uint16) bool {
|
|
return queryType != dns.TypeNS && queryType != dns.TypeSOA
|
|
}
|
|
|
|
func responseFromWildRecords(originalName, wildName string, wildRecords []dns.RR) lookupResult {
|
|
records := make([]dns.RR, len(wildRecords))
|
|
for i, record := range wildRecords {
|
|
copiedRecord := dns.Copy(record)
|
|
copiedRecord.Header().Name = originalName
|
|
records[i] = copiedRecord
|
|
}
|
|
|
|
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
|
}
|
|
|
|
// lookupCNAMEChain follows a CNAME chain and returns the CNAME records along with
|
|
// the final resolved record of the requested type. This is required for musl libc
|
|
// compatibility, which expects the full answer chain rather than just the CNAME.
|
|
func (d *Resolver) lookupCNAMEChain(logger *log.Entry, cnameQuestion dns.Question, targetType uint16) lookupResult {
|
|
const maxDepth = 8
|
|
var chain []dns.RR
|
|
|
|
for range maxDepth {
|
|
cnameRecords := d.getRecords(cnameQuestion)
|
|
if len(cnameRecords) == 0 && supportsWildcard(targetType) {
|
|
wildQuestion := transformToWildcard(cnameQuestion)
|
|
if wildRecords := d.getRecords(wildQuestion); len(wildRecords) > 0 {
|
|
cnameRecords = responseFromWildRecords(cnameQuestion.Name, wildQuestion.Name, wildRecords).records
|
|
}
|
|
}
|
|
|
|
if len(cnameRecords) == 0 {
|
|
break
|
|
}
|
|
|
|
chain = append(chain, cnameRecords...)
|
|
|
|
cname, ok := cnameRecords[0].(*dns.CNAME)
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
targetName := strings.ToLower(cname.Target)
|
|
targetResult := d.resolveCNAMETarget(logger, targetName, targetType, cnameQuestion.Qclass)
|
|
|
|
// keep following chain
|
|
if targetResult.rcode == -1 {
|
|
cnameQuestion = dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: cnameQuestion.Qclass}
|
|
continue
|
|
}
|
|
|
|
return d.buildChainResult(chain, targetResult)
|
|
}
|
|
|
|
if len(chain) > 0 {
|
|
return lookupResult{records: chain, rcode: dns.RcodeSuccess}
|
|
}
|
|
return lookupResult{rcode: dns.RcodeSuccess}
|
|
}
|
|
|
|
// buildChainResult combines CNAME chain records with the target resolution result.
|
|
// Per RFC 6604, the final rcode is propagated through the chain.
|
|
func (d *Resolver) buildChainResult(chain []dns.RR, target lookupResult) lookupResult {
|
|
records := chain
|
|
if len(target.records) > 0 {
|
|
records = append(records, target.records...)
|
|
}
|
|
|
|
// preserve hasExternalData for SERVFAIL so caller knows the error came from upstream
|
|
if target.hasExternalData && target.rcode == dns.RcodeServerFailure {
|
|
return lookupResult{
|
|
records: records,
|
|
rcode: dns.RcodeServerFailure,
|
|
hasExternalData: true,
|
|
}
|
|
}
|
|
|
|
return lookupResult{
|
|
records: records,
|
|
rcode: target.rcode,
|
|
hasExternalData: target.hasExternalData,
|
|
}
|
|
}
|
|
|
|
// resolveCNAMETarget attempts to resolve a CNAME target name.
|
|
// Returns rcode=-1 to signal "keep following the chain".
|
|
func (d *Resolver) resolveCNAMETarget(logger *log.Entry, targetName string, targetType uint16, qclass uint16) lookupResult {
|
|
if records := d.getRecords(dns.Question{Name: targetName, Qtype: targetType, Qclass: qclass}); len(records) > 0 {
|
|
return lookupResult{records: records, rcode: dns.RcodeSuccess}
|
|
}
|
|
|
|
// another CNAME, keep following
|
|
if d.hasRecord(dns.Question{Name: targetName, Qtype: dns.TypeCNAME, Qclass: qclass}) {
|
|
return lookupResult{rcode: -1}
|
|
}
|
|
|
|
// domain exists locally but not this record type (NODATA)
|
|
if d.hasRecordsForDomain(domain.Domain(targetName), targetType) {
|
|
return lookupResult{rcode: dns.RcodeSuccess}
|
|
}
|
|
|
|
// in our zone but doesn't exist (NXDOMAIN)
|
|
if d.isInManagedZone(targetName) {
|
|
return lookupResult{rcode: dns.RcodeNameError}
|
|
}
|
|
|
|
return d.resolveExternal(logger, targetName, targetType)
|
|
}
|
|
|
|
func (d *Resolver) getRecords(q dns.Question) []dns.RR {
|
|
d.mu.RLock()
|
|
defer d.mu.RUnlock()
|
|
return d.records[q]
|
|
}
|
|
|
|
func (d *Resolver) hasRecord(q dns.Question) bool {
|
|
d.mu.RLock()
|
|
defer d.mu.RUnlock()
|
|
_, ok := d.records[q]
|
|
return ok
|
|
}
|
|
|
|
// resolveExternal resolves a domain name using the system resolver.
|
|
// This is used to resolve CNAME targets that point outside our local zone,
|
|
// which is required for musl libc compatibility (musl expects complete answers).
|
|
func (d *Resolver) resolveExternal(logger *log.Entry, name string, qtype uint16) lookupResult {
|
|
network := resutil.NetworkForQtype(qtype)
|
|
if network == "" {
|
|
return lookupResult{rcode: dns.RcodeNotImplemented}
|
|
}
|
|
|
|
resolver := d.resolver
|
|
if resolver == nil {
|
|
resolver = net.DefaultResolver
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(d.ctx, externalResolutionTimeout)
|
|
defer cancel()
|
|
|
|
result := resutil.LookupIP(ctx, resolver, network, name, qtype)
|
|
if result.Err != nil {
|
|
d.logDNSError(logger, name, qtype, result.Err)
|
|
return lookupResult{rcode: result.Rcode, hasExternalData: true}
|
|
}
|
|
|
|
return lookupResult{
|
|
records: resutil.IPsToRRs(name, result.IPs, 60),
|
|
rcode: dns.RcodeSuccess,
|
|
hasExternalData: true,
|
|
}
|
|
}
|
|
|
|
// logDNSError logs DNS resolution errors for debugging.
|
|
func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, err error) {
|
|
qtypeName := dns.TypeToString[qtype]
|
|
|
|
var dnsErr *net.DNSError
|
|
if !errors.As(err, &dnsErr) {
|
|
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
|
return
|
|
}
|
|
|
|
if dnsErr.IsNotFound {
|
|
logger.Tracef("DNS target not found: %s type %s", hostname, qtypeName)
|
|
return
|
|
}
|
|
|
|
if dnsErr.Server != "" {
|
|
logger.Debugf("DNS resolution failed for %s type %s server=%s: %v", hostname, qtypeName, dnsErr.Server, err)
|
|
} else {
|
|
logger.Debugf("DNS resolution failed for %s type %s: %v", hostname, qtypeName, err)
|
|
}
|
|
}
|
|
|
|
// Update replaces all zones and their records
|
|
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
|
d.mu.Lock()
|
|
defer d.mu.Unlock()
|
|
|
|
maps.Clear(d.records)
|
|
maps.Clear(d.domains)
|
|
maps.Clear(d.zones)
|
|
|
|
for _, zone := range customZones {
|
|
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
|
|
d.zones[zoneDomain] = zone.NonAuthoritative
|
|
|
|
for _, rec := range zone.Records {
|
|
if err := d.registerRecord(rec); err != nil {
|
|
log.Warnf("failed to register the record (%s): %v", rec, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// RegisterRecord stores a new record by appending it to any existing list
|
|
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
|
|
d.mu.Lock()
|
|
defer d.mu.Unlock()
|
|
|
|
return d.registerRecord(record)
|
|
}
|
|
|
|
// registerRecord performs the registration with the lock already held
|
|
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
|
|
rr, err := dns.NewRR(record.String())
|
|
if err != nil {
|
|
return fmt.Errorf("register record: %w", err)
|
|
}
|
|
|
|
rr.Header().Rdlength = record.Len()
|
|
header := rr.Header()
|
|
q := dns.Question{
|
|
Name: strings.ToLower(dns.Fqdn(header.Name)),
|
|
Qtype: header.Rrtype,
|
|
Qclass: header.Class,
|
|
}
|
|
|
|
d.records[q] = append(d.records[q], rr)
|
|
d.domains[domain.Domain(q.Name)] = struct{}{}
|
|
|
|
return nil
|
|
}
|