mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 03:36:41 +00:00
692 lines
22 KiB
Go
692 lines
22 KiB
Go
package mgmt
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/sync/singleflight"
|
|
|
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
)
|
|
|
|
const (
|
|
dnsTimeout = 5 * time.Second
|
|
defaultTTL = 300 * time.Second
|
|
refreshBackoff = 30 * time.Second
|
|
|
|
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
|
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
|
)
|
|
|
|
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
|
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
|
// system resolver.
|
|
type ChainResolver interface {
|
|
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
|
HasRootHandlerAtOrBelow(maxPriority int) bool
|
|
}
|
|
|
|
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
|
// records and cachedAt are set at construction and treated as immutable;
|
|
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
|
// Resolver.mutex.
|
|
type cachedRecord struct {
|
|
records []dns.RR
|
|
cachedAt time.Time
|
|
lastFailedRefresh time.Time
|
|
consecFailures int
|
|
}
|
|
|
|
// Resolver caches critical NetBird infrastructure domains.
|
|
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
|
type Resolver struct {
|
|
records map[dns.Question]*cachedRecord
|
|
mgmtDomain *domain.Domain
|
|
serverDomains *dnsconfig.ServerDomains
|
|
mutex sync.RWMutex
|
|
|
|
chain ChainResolver
|
|
chainMaxPriority int
|
|
refreshGroup singleflight.Group
|
|
|
|
// refreshing tracks questions whose refresh is running via the OS
|
|
// fallback path. A ServeDNS hit for a question in this map indicates
|
|
// the OS resolver routed the recursive query back to us (loop). Only
|
|
// the OS path arms this so chain-path refreshes don't produce false
|
|
// positives. The atomic bool is CAS-flipped once per refresh to
|
|
// throttle the warning log.
|
|
refreshing map[dns.Question]*atomic.Bool
|
|
|
|
cacheTTL time.Duration
|
|
}
|
|
|
|
// NewResolver creates a new management domains cache resolver.
|
|
func NewResolver() *Resolver {
|
|
return &Resolver{
|
|
records: make(map[dns.Question]*cachedRecord),
|
|
refreshing: make(map[dns.Question]*atomic.Bool),
|
|
cacheTTL: resolveCacheTTL(),
|
|
}
|
|
}
|
|
|
|
// String returns a string representation of the resolver.
|
|
func (m *Resolver) String() string {
|
|
return "MgmtCacheResolver"
|
|
}
|
|
|
|
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
|
// maxPriority caps which handlers may answer refresh queries (typically
|
|
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
|
// mgmt/route/local handlers are skipped).
|
|
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
|
m.mutex.Lock()
|
|
m.chain = chain
|
|
m.chainMaxPriority = maxPriority
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
|
// immediately and refreshed asynchronously (stale-while-revalidate).
|
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
if len(r.Question) == 0 {
|
|
m.continueToNext(w, r)
|
|
return
|
|
}
|
|
|
|
question := r.Question[0]
|
|
question.Name = strings.ToLower(dns.Fqdn(question.Name))
|
|
|
|
if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
|
|
m.continueToNext(w, r)
|
|
return
|
|
}
|
|
|
|
m.mutex.RLock()
|
|
cached, found := m.records[question]
|
|
inflight := m.refreshing[question]
|
|
var shouldRefresh bool
|
|
if found {
|
|
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
|
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
|
shouldRefresh = stale && !inBackoff
|
|
}
|
|
m.mutex.RUnlock()
|
|
|
|
if !found {
|
|
m.continueToNext(w, r)
|
|
return
|
|
}
|
|
|
|
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
|
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
|
question.Name)
|
|
}
|
|
|
|
// Skip scheduling a refresh goroutine if one is already inflight for
|
|
// this question; singleflight would dedup anyway but skipping avoids
|
|
// a parked goroutine per stale hit under bursty load.
|
|
if shouldRefresh && inflight == nil {
|
|
m.scheduleRefresh(question, cached)
|
|
}
|
|
|
|
resp := &dns.Msg{}
|
|
resp.SetReply(r)
|
|
resp.Authoritative = false
|
|
resp.RecursionAvailable = true
|
|
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
|
|
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
|
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write response: %v", err)
|
|
}
|
|
}
|
|
|
|
// MatchSubdomains returns false since this resolver only handles exact domain matches
|
|
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
|
|
func (m *Resolver) MatchSubdomains() bool {
|
|
return false
|
|
}
|
|
|
|
// continueToNext signals the handler chain to continue to the next handler.
|
|
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|
resp := &dns.Msg{}
|
|
resp.SetRcode(r, dns.RcodeNameError)
|
|
resp.MsgHdr.Zero = true
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write continue signal: %v", err)
|
|
}
|
|
}
|
|
|
|
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
|
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
|
// entry for that qtype.
|
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
|
defer cancel()
|
|
|
|
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
|
|
|
if errA != nil && errAAAA != nil {
|
|
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
|
}
|
|
|
|
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
|
if err := errors.Join(errA, errAAAA); err != nil {
|
|
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
|
}
|
|
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
|
}
|
|
|
|
now := time.Now()
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
|
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
|
|
|
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
|
|
|
return nil
|
|
}
|
|
|
|
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
|
// untouched on error. Caller holds m.mutex.
|
|
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
|
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
|
switch {
|
|
case len(records) > 0:
|
|
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
|
case err == nil:
|
|
delete(m.records, q)
|
|
}
|
|
}
|
|
|
|
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
|
// unique in-flight key; bursty stale hits share its channel. expected is the
|
|
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
|
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
|
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
|
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
|
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
|
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
|
return nil, m.refreshQuestion(question, expected)
|
|
})
|
|
}
|
|
|
|
// refreshQuestion replaces the cached records on success, or marks the entry
|
|
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
|
// a resolver loop by spotting a query for this same question arriving on us.
|
|
// expected pins the cache entry observed at schedule time; mutations only apply
|
|
// if m.records[question] still points at it.
|
|
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
|
defer cancel()
|
|
|
|
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
|
if err != nil {
|
|
m.markRefreshFailed(question, expected)
|
|
return fmt.Errorf("parse domain: %w", err)
|
|
}
|
|
|
|
records, err := m.lookupRecords(ctx, d, question)
|
|
if err != nil {
|
|
fails := m.markRefreshFailed(question, expected)
|
|
logf := log.Warnf
|
|
if fails == 0 || fails > 1 {
|
|
logf = log.Debugf
|
|
}
|
|
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
|
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
|
return err
|
|
}
|
|
|
|
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
|
if len(records) == 0 {
|
|
m.mutex.Lock()
|
|
if m.records[question] == expected {
|
|
delete(m.records, question)
|
|
m.mutex.Unlock()
|
|
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
return nil
|
|
}
|
|
m.mutex.Unlock()
|
|
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
return nil
|
|
}
|
|
|
|
now := time.Now()
|
|
m.mutex.Lock()
|
|
if m.records[question] != expected {
|
|
m.mutex.Unlock()
|
|
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
return nil
|
|
}
|
|
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
|
m.mutex.Unlock()
|
|
|
|
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
return nil
|
|
}
|
|
|
|
func (m *Resolver) markRefreshing(question dns.Question) {
|
|
m.mutex.Lock()
|
|
m.refreshing[question] = &atomic.Bool{}
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
func (m *Resolver) clearRefreshing(question dns.Question) {
|
|
m.mutex.Lock()
|
|
delete(m.refreshing, question)
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
|
// count so callers can downgrade subsequent failure logs to debug.
|
|
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
c, ok := m.records[question]
|
|
if !ok || c != expected {
|
|
return 0
|
|
}
|
|
c.lastFailedRefresh = time.Now()
|
|
c.consecFailures++
|
|
return c.consecFailures
|
|
}
|
|
|
|
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
|
// callers tell records, NODATA (nil err, no records), and failure apart.
|
|
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
|
m.mutex.RLock()
|
|
chain := m.chain
|
|
maxPriority := m.chainMaxPriority
|
|
m.mutex.RUnlock()
|
|
|
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
|
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
|
return
|
|
}
|
|
|
|
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
|
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
|
// not the system resolver, so net.DefaultResolver will not loop back.
|
|
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
|
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
|
return
|
|
}
|
|
|
|
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
|
// arms the loop detector for the duration of its call so that ServeDNS can
|
|
// spot the OS resolver routing the recursive query back to us.
|
|
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
|
m.mutex.RLock()
|
|
chain := m.chain
|
|
maxPriority := m.chainMaxPriority
|
|
m.mutex.RUnlock()
|
|
|
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
|
}
|
|
|
|
// TODO: drop once every supported OS registers a fallback resolver.
|
|
m.markRefreshing(q)
|
|
defer m.clearRefreshing(q)
|
|
|
|
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
|
}
|
|
|
|
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
|
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
|
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
|
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
msg := &dns.Msg{}
|
|
msg.SetQuestion(dnsName, qtype)
|
|
msg.RecursionDesired = true
|
|
|
|
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("chain resolve: %w", err)
|
|
}
|
|
if resp == nil {
|
|
return nil, fmt.Errorf("chain resolve returned nil response")
|
|
}
|
|
if resp.Rcode != dns.RcodeSuccess {
|
|
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
|
}
|
|
|
|
ttl := uint32(m.cacheTTL.Seconds())
|
|
owners := cnameOwners(dnsName, resp.Answer)
|
|
var filtered []dns.RR
|
|
for _, rr := range resp.Answer {
|
|
h := rr.Header()
|
|
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
|
continue
|
|
}
|
|
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
|
continue
|
|
}
|
|
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
|
filtered = append(filtered, cp)
|
|
}
|
|
}
|
|
return filtered, nil
|
|
}
|
|
|
|
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
|
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
|
// returns (nil, nil).
|
|
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
network := resutil.NetworkForQtype(qtype)
|
|
if network == "" {
|
|
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
|
}
|
|
|
|
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
|
|
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
|
if result.Rcode == dns.RcodeSuccess {
|
|
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
|
}
|
|
|
|
if result.Err != nil {
|
|
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
|
}
|
|
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
|
}
|
|
|
|
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
|
// so downstream resolvers don't cache an answer for longer than we will.
|
|
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
|
remaining := m.cacheTTL - time.Since(cachedAt)
|
|
if remaining <= 0 {
|
|
return 0
|
|
}
|
|
return uint32((remaining + time.Second - 1) / time.Second)
|
|
}
|
|
|
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
|
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
|
if mgmtURL == nil {
|
|
return nil
|
|
}
|
|
|
|
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
|
|
if err != nil {
|
|
return fmt.Errorf("extract domain from URL: %w", err)
|
|
}
|
|
|
|
m.mutex.Lock()
|
|
m.mgmtDomain = &d
|
|
m.mutex.Unlock()
|
|
|
|
if err := m.AddDomain(ctx, d); err != nil {
|
|
return fmt.Errorf("add domain: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveDomain removes a domain from the cache.
|
|
func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
|
delete(m.records, qA)
|
|
delete(m.records, qAAAA)
|
|
delete(m.refreshing, qA)
|
|
delete(m.refreshing, qAAAA)
|
|
|
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
|
return nil
|
|
}
|
|
|
|
// GetCachedDomains returns a list of all cached domains.
|
|
func (m *Resolver) GetCachedDomains() domain.List {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
domainSet := make(map[domain.Domain]struct{})
|
|
for question := range m.records {
|
|
domainName := strings.TrimSuffix(question.Name, ".")
|
|
domainSet[domain.Domain(domainName)] = struct{}{}
|
|
}
|
|
|
|
domains := make(domain.List, 0, len(domainSet))
|
|
for d := range domainSet {
|
|
domains = append(domains, d)
|
|
}
|
|
|
|
return domains
|
|
}
|
|
|
|
// UpdateFromServerDomains updates the cache with server domains from network configuration.
|
|
// It merges new domains with existing ones, replacing entire domain types when updated.
|
|
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
|
|
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
|
|
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
|
var removedDomains domain.List
|
|
|
|
if len(newDomains) > 0 {
|
|
m.mutex.Lock()
|
|
if m.serverDomains == nil {
|
|
m.serverDomains = &dnsconfig.ServerDomains{}
|
|
}
|
|
updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains)
|
|
m.serverDomains = &updatedServerDomains
|
|
m.mutex.Unlock()
|
|
|
|
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
|
currentDomains := m.GetCachedDomains()
|
|
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
|
}
|
|
|
|
m.addNewDomains(ctx, newDomains)
|
|
|
|
return removedDomains, nil
|
|
}
|
|
|
|
// removeStaleDomains removes cached domains not present in the target domain list.
|
|
// Management domains are preserved and never removed during server domain updates.
|
|
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
|
|
var removedDomains domain.List
|
|
|
|
for _, currentDomain := range currentDomains {
|
|
if m.isDomainInList(currentDomain, newDomains) {
|
|
continue
|
|
}
|
|
|
|
if m.isManagementDomain(currentDomain) {
|
|
continue
|
|
}
|
|
|
|
removedDomains = append(removedDomains, currentDomain)
|
|
if err := m.RemoveDomain(currentDomain); err != nil {
|
|
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
|
|
}
|
|
}
|
|
|
|
return removedDomains
|
|
}
|
|
|
|
// mergeServerDomains merges new server domains with existing ones.
|
|
// When a domain type is provided in the new domains, it completely replaces that type.
|
|
func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains {
|
|
merged := existing
|
|
|
|
if incoming.Signal != "" {
|
|
merged.Signal = incoming.Signal
|
|
}
|
|
if len(incoming.Relay) > 0 {
|
|
merged.Relay = incoming.Relay
|
|
}
|
|
if incoming.Flow != "" {
|
|
merged.Flow = incoming.Flow
|
|
}
|
|
if len(incoming.Stuns) > 0 {
|
|
merged.Stuns = incoming.Stuns
|
|
}
|
|
if len(incoming.Turns) > 0 {
|
|
merged.Turns = incoming.Turns
|
|
}
|
|
|
|
return merged
|
|
}
|
|
|
|
// isDomainInList checks if domain exists in the list
|
|
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
|
|
for _, d := range list {
|
|
if domain.SafeString() == d.SafeString() {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isManagementDomain checks if domain is the protected management domain
|
|
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
|
}
|
|
|
|
// addNewDomains resolves and caches all domains from the update
|
|
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
|
for _, newDomain := range newDomains {
|
|
if err := m.AddDomain(ctx, newDomain); err != nil {
|
|
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
|
} else {
|
|
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
|
var domains domain.List
|
|
|
|
if serverDomains.Signal != "" {
|
|
domains = append(domains, serverDomains.Signal)
|
|
}
|
|
|
|
for _, relay := range serverDomains.Relay {
|
|
if relay != "" {
|
|
domains = append(domains, relay)
|
|
}
|
|
}
|
|
|
|
// Flow receiver domain is intentionally excluded from caching.
|
|
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
|
// causes TLS certificate verification failures on reconnect.
|
|
|
|
for _, stun := range serverDomains.Stuns {
|
|
if stun != "" {
|
|
domains = append(domains, stun)
|
|
}
|
|
}
|
|
|
|
for _, turn := range serverDomains.Turns {
|
|
if turn != "" {
|
|
domains = append(domains, turn)
|
|
}
|
|
}
|
|
|
|
return domains
|
|
}
|
|
|
|
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
|
// A/AAAA records return nil.
|
|
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
|
switch r := rr.(type) {
|
|
case *dns.A:
|
|
cp := *r
|
|
cp.Hdr.Name = owner
|
|
cp.Hdr.Ttl = ttl
|
|
cp.A = slices.Clone(r.A)
|
|
return &cp
|
|
case *dns.AAAA:
|
|
cp := *r
|
|
cp.Hdr.Name = owner
|
|
cp.Hdr.Ttl = ttl
|
|
cp.AAAA = slices.Clone(r.AAAA)
|
|
return &cp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
|
// stamping ttl so the response shares no memory with the cached slice.
|
|
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
|
out := make([]dns.RR, 0, len(records))
|
|
for _, rr := range records {
|
|
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
|
out = append(out, cp)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
|
// in answer, iterating until fixed point so out-of-order chains resolve.
|
|
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
|
owners := map[string]bool{dnsName: true}
|
|
for {
|
|
added := false
|
|
for _, rr := range answer {
|
|
cname, ok := rr.(*dns.CNAME)
|
|
if !ok {
|
|
continue
|
|
}
|
|
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
|
if !owners[name] {
|
|
continue
|
|
}
|
|
target := strings.ToLower(dns.Fqdn(cname.Target))
|
|
if !owners[target] {
|
|
owners[target] = true
|
|
added = true
|
|
}
|
|
}
|
|
if !added {
|
|
return owners
|
|
}
|
|
}
|
|
}
|
|
|
|
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
|
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
|
func resolveCacheTTL() time.Duration {
|
|
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
|
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
|
return d
|
|
}
|
|
}
|
|
return defaultTTL
|
|
}
|