Reject empty AddDomain results, filter chain answers by CNAME owner, dedup A/AAAA clone

This commit is contained in:
Viktor Liu
2026-04-22 11:40:54 +02:00
parent 4c25ac674a
commit cdc1ff8fd2

View File

@@ -31,6 +31,15 @@ const (
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
var cacheTTL = sync.OnceValue(func() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
})
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
@@ -183,8 +192,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
// Dual NODATA: don't wipe existing entries, let the caller retry.
if errA == nil && errAAAA == nil && len(aRecords) == 0 && len(aaaaRecords) == 0 {
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
@@ -551,36 +562,38 @@ func responseTTL(cachedAt time.Time) uint32 {
return uint32((remaining + time.Second - 1) / time.Second)
}
// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header
// so mutating the response never touches the cached slice.
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
out = append(out, &cp)
case *dns.AAAA:
cp := *r
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
out = append(out, &cp)
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
var cacheTTL = sync.OnceValue(func() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
})
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
@@ -601,33 +614,50 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d
}
ttl := uint32(cacheTTL().Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
if rr.Header().Class != dns.ClassINET {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
switch r := rr.(type) {
case *dns.A:
if qtype != dns.TypeA {
continue
}
filtered = append(filtered, &dns.A{
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
A: slices.Clone(r.A),
})
case *dns.AAAA:
if qtype != dns.TypeAAAA {
continue
}
filtered = append(filtered, &dns.AAAA{
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl},
AAAA: slices.Clone(r.AAAA),
})
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).