mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-07 01:10:03 +00:00
Reject empty AddDomain results, filter chain answers by CNAME owner, dedup A/AAAA clone
This commit is contained in:
@@ -31,6 +31,15 @@ const (
|
||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||
)
|
||||
|
||||
var cacheTTL = sync.OnceValue(func() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
})
|
||||
|
||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||
// system resolver.
|
||||
@@ -183,8 +192,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
// Dual NODATA: don't wipe existing entries, let the caller retry.
|
||||
if errA == nil && errAAAA == nil && len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
@@ -551,36 +562,38 @@ func responseTTL(cachedAt time.Time) uint32 {
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header
|
||||
// so mutating the response never touches the cached slice.
|
||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||
// A/AAAA records return nil.
|
||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
return &cp
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
return &cp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||
// stamping ttl so the response shares no memory with the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
out = append(out, &cp)
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
out = append(out, &cp)
|
||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||
out = append(out, cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
var cacheTTL = sync.OnceValue(func() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
})
|
||||
|
||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||
// dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache
|
||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||
@@ -601,33 +614,50 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d
|
||||
}
|
||||
|
||||
ttl := uint32(cacheTTL().Seconds())
|
||||
owners := cnameOwners(dnsName, resp.Answer)
|
||||
var filtered []dns.RR
|
||||
for _, rr := range resp.Answer {
|
||||
if rr.Header().Class != dns.ClassINET {
|
||||
h := rr.Header()
|
||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||
continue
|
||||
}
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
if qtype != dns.TypeA {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
|
||||
A: slices.Clone(r.A),
|
||||
})
|
||||
case *dns.AAAA:
|
||||
if qtype != dns.TypeAAAA {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl},
|
||||
AAAA: slices.Clone(r.AAAA),
|
||||
})
|
||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||
continue
|
||||
}
|
||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||
filtered = append(filtered, cp)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||
owners := map[string]bool{dnsName: true}
|
||||
for {
|
||||
added := false
|
||||
for _, rr := range answer {
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||
if !owners[name] {
|
||||
continue
|
||||
}
|
||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||
if !owners[target] {
|
||||
owners[target] = true
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
return owners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||
// returns (nil, nil).
|
||||
|
||||
Reference in New Issue
Block a user