mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 19:26:39 +00:00
When an exit-node peer's network-map installs a 0.0.0.0/0 default route
on the overlay interface before that peer's WireGuard key material is
active, any UDP socket dialing an off-link address is routed into wt0
and the kernel returns ENOKEY.
Two places needed fixing:
1. The mgmt cache refresh path. It reactively refreshes the
control-plane FQDNs advertised by the mgmt (api/signal/stun/turn/
the Relay pool root) after the daemon has installed its own
resolv.conf pointing at the overlay listener. Previously the
refresh dial followed the chain's upstream handler, which followed
the overlay default route and deadlocked on ENOKEY.
2. Foreign relay FQDN resolution. When a remote peer is homed on a
different relay instance than us, we need to resolve a streamline-*
subdomain that is not in the cache. That lookup went through the
same overlay-routed upstream and failed identically, deadlocking
the exit-node test whenever the relay LB put the two peers on
different instances.
Fix both by giving the mgmt cache a dedicated net.Resolver that dials
the original pre-NetBird system nameservers through nbnet.NewDialer.
The dialer marks the socket as control-plane (SO_MARK on Linux,
IP_BOUND_IF on darwin, IP_UNICAST_IF on Windows); the routemanager's
policy rules keep those sockets on the underlay regardless of the
overlay default.
Pool-root domains (the Relay entries in ServerDomains) now register
through a subdomain-matching wrapper so that instance subdomains like
streamline-de-fra1-0.relay.netbird.io also hit the mgmt cache handler.
On cache miss under a pool root, ServeDNS resolves the FQDN on demand
through the bypass resolver, caches the result, and returns it.
Pool-root membership is derived dynamically from mgmt-advertised
ServerDomains.Relay[] — no hardcoded domain lists, no protocol change.
No hardcoded fallback nameservers: if the host had no original system
resolver at all, the bypass resolver stays nil and the stale-while-
revalidate cache keeps serving. The general upstream forwarder and
the user DNS path are unchanged.
846 lines
27 KiB
Go
846 lines
27 KiB
Go
package mgmt
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/sync/singleflight"
|
|
|
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
)
|
|
|
|
const (
	// dnsTimeout bounds every individual resolve attempt (AddDomain,
	// background refresh, and on-demand pool-root resolution).
	dnsTimeout = 5 * time.Second
	// defaultTTL is how long a cached entry is considered fresh before a
	// stale-while-revalidate refresh is scheduled.
	defaultTTL = 300 * time.Second
	// refreshBackoff suppresses further refresh attempts for an entry
	// after a failed refresh.
	refreshBackoff = 30 * time.Second

	// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
	envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
|
|
|
|
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
	// ResolveInternal resolves msg using only chain handlers registered at
	// or below maxPriority and returns the resulting response message.
	ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
	// HasRootHandlerAtOrBelow reports whether a catch-all (root) handler
	// is registered at or below maxPriority, i.e. whether the chain can
	// answer arbitrary names at that priority cap.
	HasRootHandlerAtOrBelow(maxPriority int) bool
}
|
|
|
|
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
	records  []dns.RR  // resolved A/AAAA records; immutable after construction
	cachedAt time.Time // resolution time; drives staleness and response TTL

	lastFailedRefresh time.Time // zero until a refresh fails; arms refreshBackoff
	consecFailures    int       // consecutive failed refreshes, used to throttle warn logs
}
|
|
|
|
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
	// records maps normalized (lowercase FQDN) A/AAAA questions to their
	// cached answers.
	records map[dns.Question]*cachedRecord
	// mgmtDomain is the management-service domain; it is never evicted by
	// server-domain updates.
	mgmtDomain *domain.Domain
	// serverDomains is the merged set of mgmt-advertised infrastructure
	// domains (signal/relay/stun/turn/flow).
	serverDomains *dnsconfig.ServerDomains
	mutex         sync.RWMutex

	// chain, when set, resolves refreshes through the DNS handler chain
	// capped at chainMaxPriority.
	chain            ChainResolver
	chainMaxPriority int
	// refreshGroup dedups concurrent background refreshes per question.
	refreshGroup singleflight.Group

	// refreshing tracks questions whose refresh is running via the OS
	// fallback path. A ServeDNS hit for a question in this map indicates
	// the OS resolver routed the recursive query back to us (loop). Only
	// the OS path arms this so chain-path refreshes don't produce false
	// positives. The atomic bool is CAS-flipped once per refresh to
	// throttle the warning log.
	refreshing map[dns.Question]*atomic.Bool

	// cacheTTL is fixed at construction (see resolveCacheTTL) and read
	// without the mutex.
	cacheTTL time.Duration

	// bypassResolver, when non-nil, is used by osLookup instead of
	// net.DefaultResolver. It is constructed by the caller to dial the
	// original (pre-NetBird) system nameservers through a socket that
	// bypasses the overlay interface (control-plane fwmark / bound iface),
	// so that when an exit-node default route is installed before a peer
	// is handshaked the refresh does not fail with ENOKEY.
	bypassResolver *net.Resolver
}
|
|
|
|
// NewResolver creates a new management domains cache resolver.
|
|
func NewResolver() *Resolver {
|
|
return &Resolver{
|
|
records: make(map[dns.Question]*cachedRecord),
|
|
refreshing: make(map[dns.Question]*atomic.Bool),
|
|
cacheTTL: resolveCacheTTL(),
|
|
}
|
|
}
|
|
|
|
// String returns a string representation of the resolver.
|
|
func (m *Resolver) String() string {
|
|
return "MgmtCacheResolver"
|
|
}
|
|
|
|
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
|
// maxPriority caps which handlers may answer refresh queries (typically
|
|
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
|
// mgmt/route/local handlers are skipped).
|
|
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
|
m.mutex.Lock()
|
|
m.chain = chain
|
|
m.chainMaxPriority = maxPriority
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
// SetBypassResolver installs a resolver that osLookup uses instead of
|
|
// net.DefaultResolver. It is intended to dial the original (pre-NetBird)
|
|
// system nameservers through a socket that does not follow the overlay
|
|
// default route, so that a refresh initiated while an exit node is active
|
|
// but its WireGuard peer is not yet installed cannot deadlock on ENOKEY.
|
|
// Passing nil restores use of net.DefaultResolver.
|
|
func (m *Resolver) SetBypassResolver(r *net.Resolver) {
|
|
m.mutex.Lock()
|
|
m.bypassResolver = r
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
//
// If the query name is not in the cache but falls under a pool-root
// domain (a domain the mgmt advertised in ServerDomains.Relay, whose
// instance subdomains like streamline-de-fra1-0.relay.netbird.io are
// part of the relay pool), resolve it on demand through the bypass
// resolver and cache the result. This is what lets the daemon reach
// a foreign relay FQDN after an exit-node default route has been
// installed on the overlay before its peer is live.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) == 0 {
		m.continueToNext(w, r)
		return
	}

	// Normalize the owner name so it matches the lowercase-FQDN keys
	// written by AddDomain/resolveOnDemand.
	question := r.Question[0]
	question.Name = strings.ToLower(dns.Fqdn(question.Name))

	// Only address records are cached; everything else goes to the chain.
	if question.Qtype != dns.TypeA && question.Qtype != dns.TypeAAAA {
		m.continueToNext(w, r)
		return
	}

	// Snapshot cache state under the read lock; cached.records/cachedAt
	// are immutable so they are safe to use after unlocking.
	m.mutex.RLock()
	cached, found := m.records[question]
	inflight := m.refreshing[question]
	var shouldRefresh bool
	if found {
		stale := time.Since(cached.cachedAt) > m.cacheTTL
		inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
		shouldRefresh = stale && !inBackoff
	}
	m.mutex.RUnlock()

	if !found {
		if m.isUnderPoolRoot(question.Name) {
			m.resolveOnDemand(w, r, question)
			return
		}
		m.continueToNext(w, r)
		return
	}

	// A hit while an OS-path refresh is inflight means the OS resolver
	// routed the recursive query back to us; CAS throttles the warning
	// to once per refresh.
	if inflight != nil && inflight.CompareAndSwap(false, true) {
		log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
			question.Name)
	}

	// Skip scheduling a refresh goroutine if one is already inflight for
	// this question; singleflight would dedup anyway but skipping avoids
	// a parked goroutine per stale hit under bursty load.
	if shouldRefresh && inflight == nil {
		m.scheduleRefresh(question, cached)
	}

	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Authoritative = false
	resp.RecursionAvailable = true
	// Clone so the response shares no memory with the cached slice, and
	// cap the advertised TTL at the cache entry's remaining lifetime.
	resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))

	log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)

	if err := w.WriteMsg(resp); err != nil {
		log.Errorf("failed to write response: %v", err)
	}
}
|
|
|
|
// MatchSubdomains returns false by default: the bare resolver is registered
|
|
// against exact domains. Pool-root domains (currently Relay entries from
|
|
// ServerDomains) are registered through a subdomain-matching wrapper at
|
|
// the call site instead, so instance subdomains hit this handler and get
|
|
// the on-demand resolve path in ServeDNS.
|
|
func (m *Resolver) MatchSubdomains() bool {
|
|
return false
|
|
}
|
|
|
|
// isUnderPoolRoot reports whether fqdn is an instance subdomain under any
|
|
// pool-root domain advertised by the mgmt (currently ServerDomains.Relay),
|
|
// e.g. "streamline-de-fra1-0.relay.netbird.io." is under "relay.netbird.io".
|
|
// The pool-root itself is not considered a subdomain (it matches the exact
|
|
// cache entry populated by AddDomain instead).
|
|
func (m *Resolver) isUnderPoolRoot(fqdn string) bool {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
if m.serverDomains == nil {
|
|
return false
|
|
}
|
|
fqdn = strings.ToLower(strings.TrimSuffix(fqdn, "."))
|
|
for _, root := range m.serverDomains.Relay {
|
|
r := strings.ToLower(strings.TrimSuffix(root.PunycodeString(), "."))
|
|
if r == "" || fqdn == r {
|
|
continue
|
|
}
|
|
if strings.HasSuffix(fqdn, "."+r) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// resolveOnDemand resolves an uncached pool-root subdomain (e.g. a relay
// instance FQDN) through the bypass resolver path, caches the result, and
// writes it back to w. Falls through to the next handler on error so the
// normal chain can still attempt the resolve.
func (m *Resolver) resolveOnDemand(w dns.ResponseWriter, r *dns.Msg, question dns.Question) {
	d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
	if err != nil {
		log.Debugf("on-demand resolve: parse domain %q: %v", question.Name, err)
		m.continueToNext(w, r)
		return
	}

	ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
	defer cancel()

	records, err := m.lookupRecords(ctx, d, question)
	if err != nil {
		log.Debugf("on-demand resolve %s type=%s: %v",
			d.SafeString(), dns.TypeToString[question.Qtype], err)
		m.continueToNext(w, r)
		return
	}
	// NODATA: nothing to serve or cache; let the chain try.
	if len(records) == 0 {
		m.continueToNext(w, r)
		return
	}

	now := time.Now()
	m.mutex.Lock()
	// First writer wins: if a concurrent resolve already populated the
	// entry, keep it rather than overwrite with our (equivalent) result.
	if _, exists := m.records[question]; !exists {
		m.records[question] = &cachedRecord{records: records, cachedAt: now}
	}
	m.mutex.Unlock()

	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Authoritative = false
	resp.RecursionAvailable = true
	// Freshly resolved, so advertise the full cache TTL.
	resp.Answer = cloneRecordsWithTTL(records, uint32(m.cacheTTL.Seconds()))

	log.Debugf("on-demand resolved %d records for domain=%s", len(resp.Answer), question.Name)

	if err := w.WriteMsg(resp); err != nil {
		log.Errorf("failed to write on-demand response: %v", err)
	}
}
|
|
|
|
|
|
// continueToNext signals the handler chain to continue to the next handler.
|
|
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|
resp := &dns.Msg{}
|
|
resp.SetRcode(r, dns.RcodeNameError)
|
|
resp.MsgHdr.Zero = true
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write continue signal: %v", err)
|
|
}
|
|
}
|
|
|
|
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
	dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))

	ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
	defer cancel()

	aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)

	// Both families failed outright: report without touching the cache.
	if errA != nil && errAAAA != nil {
		return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
	}

	// No usable records at all (one family may have failed, the other
	// returned NODATA): also an error, with the partial cause if any.
	if len(aRecords) == 0 && len(aaaaRecords) == 0 {
		if err := errors.Join(errA, errAAAA); err != nil {
			return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
		}
		return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
	}

	now := time.Now()
	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Per family: store records, evict on NODATA, keep stale on error.
	m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
	m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)

	log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
		d.SafeString(), len(aRecords), len(aaaaRecords))

	return nil
}
|
|
|
|
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
|
// untouched on error. Caller holds m.mutex.
|
|
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
|
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
|
switch {
|
|
case len(records) > 0:
|
|
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
|
case err == nil:
|
|
delete(m.records, q)
|
|
}
|
|
}
|
|
|
|
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
|
// unique in-flight key; bursty stale hits share its channel. expected is the
|
|
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
|
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
|
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
|
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
|
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
|
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
|
return nil, m.refreshQuestion(question, expected)
|
|
})
|
|
}
|
|
|
|
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
	ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
	defer cancel()

	d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
	if err != nil {
		m.markRefreshFailed(question, expected)
		return fmt.Errorf("parse domain: %w", err)
	}

	records, err := m.lookupRecords(ctx, d, question)
	if err != nil {
		// Warn only on the first consecutive failure (fails == 1);
		// repeats and entry-changed races (fails == 0) log at debug.
		fails := m.markRefreshFailed(question, expected)
		logf := log.Warnf
		if fails == 0 || fails > 1 {
			logf = log.Debugf
		}
		logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
			d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
		return err
	}

	// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
	if len(records) == 0 {
		m.mutex.Lock()
		if m.records[question] == expected {
			delete(m.records, question)
			m.mutex.Unlock()
			log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
				d.SafeString(), dns.TypeToString[question.Qtype])
			return nil
		}
		m.mutex.Unlock()
		log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
			d.SafeString(), dns.TypeToString[question.Qtype])
		return nil
	}

	now := time.Now()
	m.mutex.Lock()
	// Only write if the entry we observed at schedule time is still the
	// one stored; otherwise a competing writer already superseded us.
	if m.records[question] != expected {
		m.mutex.Unlock()
		log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
			d.SafeString(), dns.TypeToString[question.Qtype])
		return nil
	}
	m.records[question] = &cachedRecord{records: records, cachedAt: now}
	m.mutex.Unlock()

	log.Infof("refreshed mgmt cache domain=%s type=%s",
		d.SafeString(), dns.TypeToString[question.Qtype])
	return nil
}
|
|
|
|
func (m *Resolver) markRefreshing(question dns.Question) {
|
|
m.mutex.Lock()
|
|
m.refreshing[question] = &atomic.Bool{}
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
func (m *Resolver) clearRefreshing(question dns.Question) {
|
|
m.mutex.Lock()
|
|
delete(m.refreshing, question)
|
|
m.mutex.Unlock()
|
|
}
|
|
|
|
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
|
// count so callers can downgrade subsequent failure logs to debug.
|
|
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
c, ok := m.records[question]
|
|
if !ok || c != expected {
|
|
return 0
|
|
}
|
|
c.lastFailedRefresh = time.Now()
|
|
c.consecFailures++
|
|
return c.consecFailures
|
|
}
|
|
|
|
// lookupBoth resolves A and AAAA via bypass resolver, chain, or OS.
// Per-family errors let callers tell records, NODATA (nil err, no records),
// and failure apart.
//
// Preference order:
//  1. bypassResolver (direct, overlay-bypassing dial to original system
//     nameservers; immune to the exit-node ENOKEY race).
//  2. chain (handler chain; used when NetBird is the system resolver and
//     no bypass resolver is installed).
//  3. net.DefaultResolver via osLookup (legacy fallback).
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
	// Snapshot the routing fields once so both family lookups take the
	// same path even if SetChainResolver/SetBypassResolver race with us.
	m.mutex.RLock()
	chain := m.chain
	maxPriority := m.chainMaxPriority
	bypass := m.bypassResolver
	m.mutex.RUnlock()

	if bypass != nil {
		aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
		aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
		return
	}

	if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
		aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
		aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
		return
	}

	// TODO: drop once every supported OS registers a fallback resolver. Safe
	// today: no root handler at priority ≤ PriorityUpstream means NetBird is
	// not the system resolver, so net.DefaultResolver will not loop back.
	aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
	aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
	return
}
|
|
|
|
// lookupRecords resolves a single record type. See lookupBoth for the
// preference order. The OS branch arms the loop detector for the duration
// of its call so that ServeDNS can spot the OS resolver routing the
// recursive query back to us; the bypass branch skips the loop detector
// because its dial does not enter the system resolver.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
	// Snapshot the routing fields under the read lock (same pattern as
	// lookupBoth).
	m.mutex.RLock()
	chain := m.chain
	maxPriority := m.chainMaxPriority
	bypass := m.bypassResolver
	m.mutex.RUnlock()

	if bypass != nil {
		return m.osLookup(ctx, d, q.Name, q.Qtype)
	}

	if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
		return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
	}

	// TODO: drop once every supported OS registers a fallback resolver.
	// Arm the loop detector only for this legacy OS path; see the doc
	// comment above for why the bypass branch skips it.
	m.markRefreshing(q)
	defer m.clearRefreshing(q)

	return m.osLookup(ctx, d, q.Name, q.Qtype)
}
|
|
|
|
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
	msg := &dns.Msg{}
	msg.SetQuestion(dnsName, qtype)
	msg.RecursionDesired = true

	resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
	if err != nil {
		return nil, fmt.Errorf("chain resolve: %w", err)
	}
	if resp == nil {
		return nil, fmt.Errorf("chain resolve returned nil response")
	}
	if resp.Rcode != dns.RcodeSuccess {
		return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
	}

	ttl := uint32(m.cacheTTL.Seconds())
	// Accept only records owned by dnsName itself or by a name reachable
	// from it through the answer's CNAME chain; anything else is dropped.
	owners := cnameOwners(dnsName, resp.Answer)
	var filtered []dns.RR
	for _, rr := range resp.Answer {
		h := rr.Header()
		if h.Class != dns.ClassINET || h.Rrtype != qtype {
			continue
		}
		if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
			continue
		}
		// Retarget the record to dnsName with our cache TTL; cloneIPRecord
		// returns nil for non-A/AAAA types.
		if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
			filtered = append(filtered, cp)
		}
	}
	return filtered, nil
}
|
|
|
|
// osLookup resolves a single family via the bypass resolver (if configured)
// or net.DefaultResolver using resutil, which disambiguates NODATA from
// NXDOMAIN and Unmaps v4-mapped-v6. NODATA returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
	network := resutil.NetworkForQtype(qtype)
	if network == "" {
		return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
	}

	log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
	defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])

	// Re-read the bypass resolver here (not just in callers) so a direct
	// caller of osLookup also honors SetBypassResolver.
	m.mutex.RLock()
	resolver := m.bypassResolver
	m.mutex.RUnlock()
	if resolver == nil {
		resolver = net.DefaultResolver
	}

	result := resutil.LookupIP(ctx, resolver, network, d.PunycodeString(), qtype)
	// Success covers both records and NODATA (empty IPs → empty RR slice).
	if result.Rcode == dns.RcodeSuccess {
		return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
	}

	if result.Err != nil {
		return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
	}
	return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
|
|
|
|
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
|
// so downstream resolvers don't cache an answer for longer than we will.
|
|
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
|
remaining := m.cacheTTL - time.Since(cachedAt)
|
|
if remaining <= 0 {
|
|
return 0
|
|
}
|
|
return uint32((remaining + time.Second - 1) / time.Second)
|
|
}
|
|
|
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
|
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
|
if mgmtURL == nil {
|
|
return nil
|
|
}
|
|
|
|
d, err := dnsconfig.ExtractValidDomain(mgmtURL.String())
|
|
if err != nil {
|
|
return fmt.Errorf("extract domain from URL: %w", err)
|
|
}
|
|
|
|
m.mutex.Lock()
|
|
m.mgmtDomain = &d
|
|
m.mutex.Unlock()
|
|
|
|
if err := m.AddDomain(ctx, d); err != nil {
|
|
return fmt.Errorf("add domain: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveDomain removes a domain from the cache.
|
|
func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
|
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
|
delete(m.records, qA)
|
|
delete(m.records, qAAAA)
|
|
delete(m.refreshing, qA)
|
|
delete(m.refreshing, qAAAA)
|
|
|
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
|
return nil
|
|
}
|
|
|
|
// GetPoolRootDomains returns the set of domains that should be registered
|
|
// with subdomain matching (currently the Relay entries from ServerDomains).
|
|
// Instance subdomains under these roots are resolved on demand in ServeDNS.
|
|
func (m *Resolver) GetPoolRootDomains() domain.List {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
if m.serverDomains == nil {
|
|
return nil
|
|
}
|
|
out := make(domain.List, 0, len(m.serverDomains.Relay))
|
|
for _, d := range m.serverDomains.Relay {
|
|
if d != "" {
|
|
out = append(out, d)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// GetCachedDomains returns a list of all cached domains.
|
|
func (m *Resolver) GetCachedDomains() domain.List {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
domainSet := make(map[domain.Domain]struct{})
|
|
for question := range m.records {
|
|
domainName := strings.TrimSuffix(question.Name, ".")
|
|
domainSet[domain.Domain(domainName)] = struct{}{}
|
|
}
|
|
|
|
domains := make(domain.List, 0, len(domainSet))
|
|
for d := range domainSet {
|
|
domains = append(domains, d)
|
|
}
|
|
|
|
return domains
|
|
}
|
|
|
|
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
	newDomains := m.extractDomainsFromServerDomains(serverDomains)
	var removedDomains domain.List

	if len(newDomains) > 0 {
		m.mutex.Lock()
		if m.serverDomains == nil {
			m.serverDomains = &dnsconfig.ServerDomains{}
		}
		updatedServerDomains := m.mergeServerDomains(*m.serverDomains, serverDomains)
		m.serverDomains = &updatedServerDomains
		m.mutex.Unlock()

		// Evict cached domains the merged config no longer advertises;
		// removeStaleDomains always preserves the mgmt domain.
		allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
		currentDomains := m.GetCachedDomains()
		removedDomains = m.removeStaleDomains(currentDomains, allDomains)
	}

	// Resolve and cache the incoming domains (no-op when the update was
	// empty).
	m.addNewDomains(ctx, newDomains)

	return removedDomains, nil
}
|
|
|
|
// removeStaleDomains removes cached domains not present in the target domain list.
|
|
// Management domains are preserved and never removed during server domain updates.
|
|
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
|
|
var removedDomains domain.List
|
|
|
|
for _, currentDomain := range currentDomains {
|
|
if m.isDomainInList(currentDomain, newDomains) {
|
|
continue
|
|
}
|
|
|
|
if m.isManagementDomain(currentDomain) {
|
|
continue
|
|
}
|
|
|
|
removedDomains = append(removedDomains, currentDomain)
|
|
if err := m.RemoveDomain(currentDomain); err != nil {
|
|
log.Warnf("failed to remove domain=%s: %v", currentDomain.SafeString(), err)
|
|
}
|
|
}
|
|
|
|
return removedDomains
|
|
}
|
|
|
|
// mergeServerDomains merges new server domains with existing ones.
|
|
// When a domain type is provided in the new domains, it completely replaces that type.
|
|
func (m *Resolver) mergeServerDomains(existing, incoming dnsconfig.ServerDomains) dnsconfig.ServerDomains {
|
|
merged := existing
|
|
|
|
if incoming.Signal != "" {
|
|
merged.Signal = incoming.Signal
|
|
}
|
|
if len(incoming.Relay) > 0 {
|
|
merged.Relay = incoming.Relay
|
|
}
|
|
if incoming.Flow != "" {
|
|
merged.Flow = incoming.Flow
|
|
}
|
|
if len(incoming.Stuns) > 0 {
|
|
merged.Stuns = incoming.Stuns
|
|
}
|
|
if len(incoming.Turns) > 0 {
|
|
merged.Turns = incoming.Turns
|
|
}
|
|
|
|
return merged
|
|
}
|
|
|
|
// isDomainInList checks if domain exists in the list
|
|
func (m *Resolver) isDomainInList(domain domain.Domain, list domain.List) bool {
|
|
for _, d := range list {
|
|
if domain.SafeString() == d.SafeString() {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// isManagementDomain checks if domain is the protected management domain
|
|
func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
|
}
|
|
|
|
// addNewDomains resolves and caches all domains from the update
|
|
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
|
for _, newDomain := range newDomains {
|
|
if err := m.AddDomain(ctx, newDomain); err != nil {
|
|
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
|
} else {
|
|
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
|
var domains domain.List
|
|
|
|
if serverDomains.Signal != "" {
|
|
domains = append(domains, serverDomains.Signal)
|
|
}
|
|
|
|
for _, relay := range serverDomains.Relay {
|
|
if relay != "" {
|
|
domains = append(domains, relay)
|
|
}
|
|
}
|
|
|
|
// Flow receiver domain is intentionally excluded from caching.
|
|
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
|
// causes TLS certificate verification failures on reconnect.
|
|
|
|
for _, stun := range serverDomains.Stuns {
|
|
if stun != "" {
|
|
domains = append(domains, stun)
|
|
}
|
|
}
|
|
|
|
for _, turn := range serverDomains.Turns {
|
|
if turn != "" {
|
|
domains = append(domains, turn)
|
|
}
|
|
}
|
|
|
|
return domains
|
|
}
|
|
|
|
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
|
// A/AAAA records return nil.
|
|
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
|
switch r := rr.(type) {
|
|
case *dns.A:
|
|
cp := *r
|
|
cp.Hdr.Name = owner
|
|
cp.Hdr.Ttl = ttl
|
|
cp.A = slices.Clone(r.A)
|
|
return &cp
|
|
case *dns.AAAA:
|
|
cp := *r
|
|
cp.Hdr.Name = owner
|
|
cp.Hdr.Ttl = ttl
|
|
cp.AAAA = slices.Clone(r.AAAA)
|
|
return &cp
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
|
// stamping ttl so the response shares no memory with the cached slice.
|
|
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
|
out := make([]dns.RR, 0, len(records))
|
|
for _, rr := range records {
|
|
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
|
out = append(out, cp)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
|
// in answer, iterating until fixed point so out-of-order chains resolve.
|
|
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
|
owners := map[string]bool{dnsName: true}
|
|
for {
|
|
added := false
|
|
for _, rr := range answer {
|
|
cname, ok := rr.(*dns.CNAME)
|
|
if !ok {
|
|
continue
|
|
}
|
|
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
|
if !owners[name] {
|
|
continue
|
|
}
|
|
target := strings.ToLower(dns.Fqdn(cname.Target))
|
|
if !owners[target] {
|
|
owners[target] = true
|
|
added = true
|
|
}
|
|
}
|
|
if !added {
|
|
return owners
|
|
}
|
|
}
|
|
}
|
|
|
|
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
|
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
|
func resolveCacheTTL() time.Duration {
|
|
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
|
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
|
return d
|
|
}
|
|
}
|
|
return defaultTTL
|
|
}
|