diff --git a/client/internal/dns/mgmt/bypass_resolver.go b/client/internal/dns/mgmt/bypass_resolver.go new file mode 100644 index 000000000..5a4c4442c --- /dev/null +++ b/client/internal/dns/mgmt/bypass_resolver.go @@ -0,0 +1,55 @@ +package mgmt + +import ( + "context" + "fmt" + "net" + "net/netip" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +// NewBypassResolver builds a *net.Resolver that sends queries directly to +// the supplied nameservers through a socket that bypasses the NetBird +// overlay interface. This lets the mgmt cache refresh control-plane +// FQDNs (api/signal/relay/stun/turn) even when an exit-node default +// route is installed on the overlay before its peer is live. +// +// Returns nil if nameservers is empty. The caller must not pass +// loopback/overlay IPs (e.g. 127.0.0.1, the overlay listener address); +// those would defeat the purpose of bypassing. +func NewBypassResolver(nameservers []netip.Addr) *net.Resolver { + if len(nameservers) == 0 { + return nil + } + + servers := make([]string, 0, len(nameservers)) + for _, ns := range nameservers { + if !ns.IsValid() || ns.IsLoopback() || ns.IsUnspecified() { + continue + } + servers = append(servers, netip.AddrPortFrom(ns, 53).String()) + } + if len(servers) == 0 { + return nil + } + + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { + nbDialer := nbnet.NewDialer() + var lastErr error + for _, ns := range servers { + conn, err := nbDialer.DialContext(ctx, network, ns) + if err == nil { + return conn, nil + } + lastErr = err + } + if lastErr == nil { + return nil, fmt.Errorf("no bypass nameservers configured") + } + return nil, fmt.Errorf("dial bypass nameservers: %w", lastErr) + }, + } +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 988e427fb..750c08466 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -71,6 +71,14 @@ type Resolver struct { refreshing map[dns.Question]*atomic.Bool cacheTTL time.Duration + + // bypassResolver, when non-nil, is used by osLookup instead of + // net.DefaultResolver. It is constructed by the caller to dial the + // original (pre-NetBird) system nameservers through a socket that + // bypasses the overlay interface (control-plane fwmark / bound iface), + // so that when an exit-node default route is installed before a peer + // is handshaked the refresh does not fail with ENOKEY. + bypassResolver *net.Resolver } // NewResolver creates a new management domains cache resolver. @@ -98,8 +106,28 @@ func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { m.mutex.Unlock() } +// SetBypassResolver installs a resolver that osLookup uses instead of +// net.DefaultResolver. It is intended to dial the original (pre-NetBird) +// system nameservers through a socket that does not follow the overlay +// default route, so that a refresh initiated while an exit node is active +// but its WireGuard peer is not yet installed cannot deadlock on ENOKEY. +// Passing nil restores use of net.DefaultResolver. +func (m *Resolver) SetBypassResolver(r *net.Resolver) { + m.mutex.Lock() + m.bypassResolver = r + m.mutex.Unlock() +} + // ServeDNS serves cached A/AAAA records. Stale entries are returned // immediately and refreshed asynchronously (stale-while-revalidate). +// +// If the query name is not in the cache but falls under a pool-root +// domain (a domain the mgmt advertised in ServerDomains.Relay, whose +// instance subdomains like streamline-de-fra1-0.relay.netbird.io are +// part of the relay pool), resolve it on demand through the bypass +// resolver and cache the result. This is what lets the daemon reach +// a foreign relay FQDN after an exit-node default route has been +// installed on the overlay before its peer is live. func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -126,6 +154,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { m.mutex.RUnlock() if !found { + if m.isUnderPoolRoot(question.Name) { + m.resolveOnDemand(w, r, question) + return + } m.continueToNext(w, r) return } @@ -155,12 +187,87 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -// MatchSubdomains returns false since this resolver only handles exact domain matches -// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains. +// MatchSubdomains returns false by default: the bare resolver is registered +// against exact domains. Pool-root domains (currently Relay entries from +// ServerDomains) are registered through a subdomain-matching wrapper at +// the call site instead, so instance subdomains hit this handler and get +// the on-demand resolve path in ServeDNS. func (m *Resolver) MatchSubdomains() bool { return false } +// isUnderPoolRoot reports whether fqdn is an instance subdomain under any +// pool-root domain advertised by the mgmt (currently ServerDomains.Relay), +// e.g. "streamline-de-fra1-0.relay.netbird.io." is under "relay.netbird.io". +// The pool-root itself is not considered a subdomain (it matches the exact +// cache entry populated by AddDomain instead). +func (m *Resolver) isUnderPoolRoot(fqdn string) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + if m.serverDomains == nil { + return false + } + fqdn = strings.ToLower(strings.TrimSuffix(fqdn, ".")) + for _, root := range m.serverDomains.Relay { + r := strings.ToLower(strings.TrimSuffix(root.PunycodeString(), ".")) + if r == "" || fqdn == r { + continue + } + if strings.HasSuffix(fqdn, "."+r) { + return true + } + } + return false +} + +// resolveOnDemand resolves an uncached pool-root subdomain (e.g. a relay +// instance FQDN) through the bypass resolver path, caches the result, and +// writes it back to w. Falls through to the next handler on error so the +// normal chain can still attempt the resolve. +func (m *Resolver) resolveOnDemand(w dns.ResponseWriter, r *dns.Msg, question dns.Question) { + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + log.Debugf("on-demand resolve: parse domain %q: %v", question.Name, err) + m.continueToNext(w, r) + return + } + + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + records, err := m.lookupRecords(ctx, d, question) + if err != nil { + log.Debugf("on-demand resolve %s type=%s: %v", + d.SafeString(), dns.TypeToString[question.Qtype], err) + m.continueToNext(w, r) + return + } + if len(records) == 0 { + m.continueToNext(w, r) + return + } + + now := time.Now() + m.mutex.Lock() + if _, exists := m.records[question]; !exists { + m.records[question] = &cachedRecord{records: records, cachedAt: now} + } + m.mutex.Unlock() + + resp := &dns.Msg{} + resp.SetReply(r) + resp.Authoritative = false + resp.RecursionAvailable = true + resp.Answer = cloneRecordsWithTTL(records, uint32(m.cacheTTL.Seconds())) + + log.Debugf("on-demand resolved %d records for domain=%s", len(resp.Answer), question.Name) + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write on-demand response: %v", err) + } +} + + // continueToNext signals the handler chain to continue to the next handler. func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { resp := &dns.Msg{} @@ -315,14 +422,29 @@ func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedReco return c.consecFailures } -// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let -// callers tell records, NODATA (nil err, no records), and failure apart. +// lookupBoth resolves A and AAAA via bypass resolver, chain, or OS. +// Per-family errors let callers tell records, NODATA (nil err, no records), +// and failure apart. +// +// Preference order: +// 1. bypassResolver (direct, overlay-bypassing dial to original system +// nameservers; immune to the exit-node ENOKEY race). +// 2. chain (handler chain; used when NetBird is the system resolver and +// no bypass resolver is installed). +// 3. net.DefaultResolver via osLookup (legacy fallback). func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { m.mutex.RLock() chain := m.chain maxPriority := m.chainMaxPriority + bypass := m.bypassResolver m.mutex.RUnlock() + if bypass != nil { + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) + return + } + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) @@ -337,15 +459,22 @@ func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName stri return } -// lookupRecords resolves a single record type via chain or OS. The OS branch -// arms the loop detector for the duration of its call so that ServeDNS can -// spot the OS resolver routing the recursive query back to us. +// lookupRecords resolves a single record type. See lookupBoth for the +// preference order. The OS branch arms the loop detector for the duration +// of its call so that ServeDNS can spot the OS resolver routing the +// recursive query back to us; the bypass branch skips the loop detector +// because its dial does not enter the system resolver. func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { m.mutex.RLock() chain := m.chain maxPriority := m.chainMaxPriority + bypass := m.bypassResolver m.mutex.RUnlock() + if bypass != nil { + return m.osLookup(ctx, d, q.Name, q.Qtype) + } + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) } @@ -394,9 +523,9 @@ func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxP return filtered, nil } -// osLookup resolves a single family via net.DefaultResolver using resutil, -// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA -// returns (nil, nil). +// osLookup resolves a single family via the bypass resolver (if configured) +// or net.DefaultResolver using resutil, which disambiguates NODATA from +// NXDOMAIN and Unmaps v4-mapped-v6. NODATA returns (nil, nil). func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { network := resutil.NetworkForQtype(qtype) if network == "" { @@ -406,7 +535,14 @@ func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) - result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + m.mutex.RLock() + resolver := m.bypassResolver + m.mutex.RUnlock() + if resolver == nil { + resolver = net.DefaultResolver + } + + result := resutil.LookupIP(ctx, resolver, network, d.PunycodeString(), qtype) if result.Rcode == dns.RcodeSuccess { return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil } @@ -467,6 +603,24 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { return nil } +// GetPoolRootDomains returns the set of domains that should be registered +// with subdomain matching (currently the Relay entries from ServerDomains). +// Instance subdomains under these roots are resolved on demand in ServeDNS. +func (m *Resolver) GetPoolRootDomains() domain.List { + m.mutex.RLock() + defer m.mutex.RUnlock() + if m.serverDomains == nil { + return nil + } + out := make(domain.List, 0, len(m.serverDomains.Relay)) + for _, d := range m.serverDomains.Relay { + if d != "" { + out = append(out, d) + } + } + return out +} + // GetCachedDomains returns a list of all cached domains. func (m *Resolver) GetCachedDomains() domain.List { m.mutex.RLock() diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4f54dec5..36ebd6b45 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -31,6 +31,28 @@ import ( const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" +// subdomainMatchHandler is a thin wrapper used to register a handler under +// a pool-root domain (e.g. a relay URL advertised by the mgmt) with +// subdomain matching enabled. The underlying handler's own MatchSubdomains +// is left untouched so that exact-match registrations keep their +// semantics. +type subdomainMatchHandler struct { + dns.Handler +} + +// MatchSubdomains lets the handler chain route any instance subdomain +// (e.g. streamline-de-fra1-0.relay.netbird.io) to the wrapped handler. +func (subdomainMatchHandler) MatchSubdomains() bool { return true } + +// String returns a debug-friendly name; the chain uses fmt.Stringer for +// its "registering handler X" logs. +func (h subdomainMatchHandler) String() string { + if s, ok := h.Handler.(fmt.Stringer); ok { + return s.String() + "[subdomains]" + } + return "subdomainMatchHandler" +} + // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { OnReady() @@ -597,9 +619,32 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache) } - newDomains := s.mgmtCacheResolver.GetCachedDomains() - if len(newDomains) > 0 { - s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) + // Pool-root domains (advertised by the mgmt as Relay URLs) own + // their instance subdomains. Register them through a thin + // subdomain-matching wrapper so a query like + // "streamline-de-fra1-0.relay.netbird.io" routes to the mgmt + // cache resolver, which resolves it on demand through the bypass + // resolver instead of falling through to the overlay-routed + // upstream handler. + poolRoots := s.mgmtCacheResolver.GetPoolRootDomains() + poolRootSet := make(map[domain.Domain]struct{}, len(poolRoots)) + for _, d := range poolRoots { + poolRootSet[d] = struct{}{} + } + + if len(poolRoots) > 0 { + s.registerHandler(poolRoots.ToPunycodeList(), subdomainMatchHandler{Handler: s.mgmtCacheResolver}, PriorityMgmtCache) + } + + var exactDomains domain.List + for _, d := range s.mgmtCacheResolver.GetCachedDomains() { + if _, isPool := poolRootSet[d]; isPool { + continue + } + exactDomains = append(exactDomains, d) + } + if len(exactDomains) > 0 { + s.registerHandler(exactDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache) } } @@ -759,6 +804,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { originalNameservers := hostMgrWithNS.getOriginalNameservers() if len(originalNameservers) == 0 { s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) + if s.mgmtCacheResolver != nil { + s.mgmtCacheResolver.SetBypassResolver(nil) + } return } @@ -777,6 +825,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { } handler.routeMatch = s.routeMatch + var bypassNameservers []netip.Addr for _, ns := range originalNameservers { if ns == config.ServerIP { log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) @@ -785,11 +834,22 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { addrPort := netip.AddrPortFrom(ns, DefaultPort) handler.upstreamServers = append(handler.upstreamServers, addrPort) + bypassNameservers = append(bypassNameservers, ns) } handler.deactivate = func(error) { /* always active */ } handler.reactivate = func() { /* always active */ } s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) + + // Wire a bypass resolver into the mgmt cache so its refresh path dials + // the original nameservers directly over a fwmarked socket, avoiding + // the ENOKEY deadlock that occurs when an exit-node default route is + // installed on the overlay before its peer has handshaked. Scoped to + // the mgmt cache only: ordinary user DNS still flows through the + // normal upstream path. + if s.mgmtCacheResolver != nil { + s.mgmtCacheResolver.SetBypassResolver(mgmt.NewBypassResolver(bypassNameservers)) + } } func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {