From cdc1ff8fd2a6983cf2513179bd39488591a157e2 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 11:40:54 +0200 Subject: [PATCH] Reject empty AddDomain results, filter chain answers by CNAME owner, dedup A/AAAA clone --- client/internal/dns/mgmt/mgmt.go | 114 +++++++++++++++++++------------ 1 file changed, 72 insertions(+), 42 deletions(-) diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index d6bb3d91a..03334334f 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -31,6 +31,15 @@ const ( envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" ) +var cacheTTL = sync.OnceValue(func() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +}) + // ChainResolver lets the cache refresh stale entries through the DNS handler // chain instead of net.DefaultResolver, avoiding loopback when NetBird is the // system resolver. @@ -183,8 +192,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - // Dual NODATA: don't wipe existing entries, let the caller retry. - if errA == nil && errAAAA == nil && len(aRecords) == 0 && len(aaaaRecords) == 0 { + if len(aRecords) == 0 && len(aaaaRecords) == 0 { + if err := errors.Join(errA, errAAAA); err != nil { + return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) + } return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } @@ -551,36 +562,38 @@ func responseTTL(cachedAt time.Time) uint32 { return uint32((remaining + time.Second - 1) / time.Second) } -// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header -// so mutating the response never touches the cached slice. +// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non +// A/AAAA records return nil. +func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + return &cp + case *dns.AAAA: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + return &cp + } + return nil +} + +// cloneRecordsWithTTL clones A/AAAA records preserving their owner and +// stamping ttl so the response shares no memory with the cached slice. func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { out := make([]dns.RR, 0, len(records)) for _, rr := range records { - switch r := rr.(type) { - case *dns.A: - cp := *r - cp.Hdr.Ttl = ttl - cp.A = slices.Clone(r.A) - out = append(out, &cp) - case *dns.AAAA: - cp := *r - cp.Hdr.Ttl = ttl - cp.AAAA = slices.Clone(r.AAAA) - out = append(out, &cp) + if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { + out = append(out, cp) } } return out } -var cacheTTL = sync.OnceValue(func() time.Duration { - if v := os.Getenv(envMgmtCacheTTL); v != "" { - if d, err := time.ParseDuration(v); err == nil && d > 0 { - return d - } - } - return defaultTTL -}) - // lookupViaChain resolves via the handler chain and rewrites each RR to use // dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache // target-owned records or upstream TTLs. NODATA returns (nil, nil). @@ -601,33 +614,50 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d } ttl := uint32(cacheTTL().Seconds()) + owners := cnameOwners(dnsName, resp.Answer) var filtered []dns.RR for _, rr := range resp.Answer { - if rr.Header().Class != dns.ClassINET { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { continue } - switch r := rr.(type) { - case *dns.A: - if qtype != dns.TypeA { - continue - } - filtered = append(filtered, &dns.A{ - Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, - A: slices.Clone(r.A), - }) - case *dns.AAAA: - if qtype != dns.TypeAAAA { - continue - } - filtered = append(filtered, &dns.AAAA{ - Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, - AAAA: slices.Clone(r.AAAA), - }) + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { + continue + } + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) } } return filtered, nil } +// cnameOwners returns dnsName plus every target reachable by following CNAMEs +// in answer, iterating until fixed point so out-of-order chains resolve. +func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { + owners := map[string]bool{dnsName: true} + for { + added := false + for _, rr := range answer { + cname, ok := rr.(*dns.CNAME) + if !ok { + continue + } + name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) + if !owners[name] { + continue + } + target := strings.ToLower(dns.Fqdn(cname.Target)) + if !owners[target] { + owners[target] = true + added = true + } + } + if !added { + return owners + } + } +} + // osLookup resolves a single family via net.DefaultResolver using resutil, // which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA // returns (nil, nil).