mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 11:46:40 +00:00
Compare commits
6 Commits
revert-eas
...
fix-mgmt-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cdfa11cb8 | ||
|
|
12d0edabc0 | ||
|
|
b0b52b6774 | ||
|
|
c89c30bb28 | ||
|
|
14be474e3d | ||
|
|
77ec25796e |
55
client/internal/dns/mgmt/bypass_resolver.go
Normal file
55
client/internal/dns/mgmt/bypass_resolver.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewBypassResolver builds a *net.Resolver that sends queries directly to
|
||||||
|
// the supplied nameservers through a socket that bypasses the NetBird
|
||||||
|
// overlay interface. This lets the mgmt cache refresh control-plane
|
||||||
|
// FQDNs (api/signal/relay/stun/turn) even when an exit-node default
|
||||||
|
// route is installed on the overlay before its peer is live.
|
||||||
|
//
|
||||||
|
// Returns nil if nameservers is empty. The caller must not pass
|
||||||
|
// loopback/overlay IPs (e.g. 127.0.0.1, the overlay listener address);
|
||||||
|
// those would defeat the purpose of bypassing.
|
||||||
|
func NewBypassResolver(nameservers []netip.Addr) *net.Resolver {
|
||||||
|
if len(nameservers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
servers := make([]string, 0, len(nameservers))
|
||||||
|
for _, ns := range nameservers {
|
||||||
|
if !ns.IsValid() || ns.IsLoopback() || ns.IsUnspecified() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
servers = append(servers, netip.AddrPortFrom(ns, 53).String())
|
||||||
|
}
|
||||||
|
if len(servers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
|
||||||
|
nbDialer := nbnet.NewDialer()
|
||||||
|
var lastErr error
|
||||||
|
for _, ns := range servers {
|
||||||
|
conn, err := nbDialer.DialContext(ctx, network, ns)
|
||||||
|
if err == nil {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
if lastErr == nil {
|
||||||
|
return nil, fmt.Errorf("no bypass nameservers configured")
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("dial bypass nameservers: %w", lastErr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -71,6 +72,14 @@ type Resolver struct {
|
|||||||
refreshing map[dns.Question]*atomic.Bool
|
refreshing map[dns.Question]*atomic.Bool
|
||||||
|
|
||||||
cacheTTL time.Duration
|
cacheTTL time.Duration
|
||||||
|
|
||||||
|
// bypassResolver, when non-nil, is used by osLookup instead of
|
||||||
|
// net.DefaultResolver. It is constructed by the caller to dial the
|
||||||
|
// original (pre-NetBird) system nameservers through a socket that
|
||||||
|
// bypasses the overlay interface (control-plane fwmark / bound iface),
|
||||||
|
// so that when an exit-node default route is installed before a peer
|
||||||
|
// is handshaked the refresh does not fail with ENOKEY.
|
||||||
|
bypassResolver *net.Resolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
@@ -98,8 +107,28 @@ func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
|||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetBypassResolver installs a resolver that osLookup uses instead of
|
||||||
|
// net.DefaultResolver. It is intended to dial the original (pre-NetBird)
|
||||||
|
// system nameservers through a socket that does not follow the overlay
|
||||||
|
// default route, so that a refresh initiated while an exit node is active
|
||||||
|
// but its WireGuard peer is not yet installed cannot deadlock on ENOKEY.
|
||||||
|
// Passing nil restores use of net.DefaultResolver.
|
||||||
|
func (m *Resolver) SetBypassResolver(r *net.Resolver) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.bypassResolver = r
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||||
|
//
|
||||||
|
// If the query name is not in the cache but falls under a pool-root
|
||||||
|
// domain (a domain the mgmt advertised in ServerDomains.Relay, whose
|
||||||
|
// instance subdomains like streamline-de-fra1-0.relay.netbird.io are
|
||||||
|
// part of the relay pool), resolve it on demand through the bypass
|
||||||
|
// resolver and cache the result. This is what lets the daemon reach
|
||||||
|
// a foreign relay FQDN after an exit-node default route has been
|
||||||
|
// installed on the overlay before its peer is live.
|
||||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -126,6 +155,10 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
|
if m.isUnderPoolRoot(question.Name) {
|
||||||
|
m.resolveOnDemand(w, r, question)
|
||||||
|
return
|
||||||
|
}
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,12 +188,117 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MatchSubdomains returns false since this resolver only handles exact domain matches
|
// MatchSubdomains returns false by default: the bare resolver is registered
|
||||||
// for NetBird infrastructure domains (signal, relay, flow, etc.), not their subdomains.
|
// against exact domains. Pool-root domains (currently Relay entries from
|
||||||
|
// ServerDomains) are registered through a subdomain-matching wrapper at
|
||||||
|
// the call site instead, so instance subdomains hit this handler and get
|
||||||
|
// the on-demand resolve path in ServeDNS.
|
||||||
func (m *Resolver) MatchSubdomains() bool {
|
func (m *Resolver) MatchSubdomains() bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isUnderPoolRoot reports whether fqdn is an instance subdomain under any
|
||||||
|
// pool-root domain advertised by the mgmt (currently ServerDomains.Relay),
|
||||||
|
// e.g. "streamline-de-fra1-0.relay.netbird.io." is under "relay.netbird.io".
|
||||||
|
// The pool-root itself is not considered a subdomain (it matches the exact
|
||||||
|
// cache entry populated by AddDomain instead).
|
||||||
|
//
|
||||||
|
// Canonicalization mirrors server.toZone — lowercase, strip trailing dot,
|
||||||
|
// and strip a leading "*." wildcard (via canonicalizePoolDomain) — so the
|
||||||
|
// membership check is consistent with the handler-chain registration that
|
||||||
|
// runs the same set through toZone. toZone itself lives in the parent dns
|
||||||
|
// package and cannot be imported from here without a cycle.
|
||||||
|
func (m *Resolver) isUnderPoolRoot(fqdn string) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
if m.serverDomains == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fqdn = canonicalizePoolDomain(fqdn)
|
||||||
|
if fqdn == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, root := range m.serverDomains.Relay {
|
||||||
|
r := canonicalizePoolDomain(root.PunycodeString())
|
||||||
|
if r == "" || fqdn == r {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(fqdn, "."+r) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// canonicalizePoolDomain normalizes a domain for pool-root membership
|
||||||
|
// comparison: lowercase, trailing dot stripped, leading "*." wildcard
|
||||||
|
// stripped. Matches the transformation server.toZone applies on the
|
||||||
|
// handler-registration side (modulo trailing-dot orientation, which is
|
||||||
|
// self-consistent within this file).
|
||||||
|
func canonicalizePoolDomain(s string) string {
|
||||||
|
s = strings.ToLower(strings.TrimSuffix(s, "."))
|
||||||
|
s = strings.TrimPrefix(s, "*.")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveOnDemand resolves an uncached pool-root subdomain (e.g. a relay
|
||||||
|
// instance FQDN) through the bypass resolver path, caches the result, and
|
||||||
|
// writes it back to w. Falls through to the next handler on error so the
|
||||||
|
// normal chain can still attempt the resolve.
|
||||||
|
func (m *Resolver) resolveOnDemand(w dns.ResponseWriter, r *dns.Msg, question dns.Question) {
|
||||||
|
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("on-demand resolve: parse domain %q: %v", question.Name, err)
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapse concurrent on-demand lookups for the same (name, qtype) into
|
||||||
|
// a single upstream query via singleflight. A burst of parallel queries
|
||||||
|
// for a freshly-learned pool-root subdomain (e.g. multiple peer workers
|
||||||
|
// dialing the same foreign relay, or A + AAAA racing each other) would
|
||||||
|
// otherwise each hit the bypass resolver independently. The prefix
|
||||||
|
// namespaces this key off scheduleRefresh's keyspace so the two paths
|
||||||
|
// can coexist without collisions.
|
||||||
|
key := "ondemand:" + question.Name + ":" + strconv.Itoa(int(question.Qtype))
|
||||||
|
result, err, _ := m.refreshGroup.Do(key, func() (any, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return m.lookupRecords(ctx, d, question)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("on-demand resolve %s type=%s: %v",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype], err)
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
records, _ := result.([]dns.RR)
|
||||||
|
if len(records) == 0 {
|
||||||
|
m.continueToNext(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
m.mutex.Lock()
|
||||||
|
if _, exists := m.records[question]; !exists {
|
||||||
|
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||||
|
}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Authoritative = false
|
||||||
|
resp.RecursionAvailable = true
|
||||||
|
resp.Answer = cloneRecordsWithTTL(records, uint32(m.cacheTTL.Seconds()))
|
||||||
|
|
||||||
|
log.Debugf("on-demand resolved %d records for domain=%s", len(resp.Answer), question.Name)
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write on-demand response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// continueToNext signals the handler chain to continue to the next handler.
|
// continueToNext signals the handler chain to continue to the next handler.
|
||||||
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
@@ -315,14 +453,29 @@ func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedReco
|
|||||||
return c.consecFailures
|
return c.consecFailures
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
// lookupBoth resolves A and AAAA via bypass resolver, chain, or OS.
|
||||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
// Per-family errors let callers tell records, NODATA (nil err, no records),
|
||||||
|
// and failure apart.
|
||||||
|
//
|
||||||
|
// Preference order:
|
||||||
|
// 1. bypassResolver (direct, overlay-bypassing dial to original system
|
||||||
|
// nameservers; immune to the exit-node ENOKEY race).
|
||||||
|
// 2. chain (handler chain; used when NetBird is the system resolver and
|
||||||
|
// no bypass resolver is installed).
|
||||||
|
// 3. net.DefaultResolver via osLookup (legacy fallback).
|
||||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
chain := m.chain
|
chain := m.chain
|
||||||
maxPriority := m.chainMaxPriority
|
maxPriority := m.chainMaxPriority
|
||||||
|
bypass := m.bypassResolver
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if bypass != nil {
|
||||||
|
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||||
|
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||||
@@ -337,15 +490,22 @@ func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName stri
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
// lookupRecords resolves a single record type. See lookupBoth for the
|
||||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
// preference order. The OS branch arms the loop detector for the duration
|
||||||
// spot the OS resolver routing the recursive query back to us.
|
// of its call so that ServeDNS can spot the OS resolver routing the
|
||||||
|
// recursive query back to us; the bypass branch skips the loop detector
|
||||||
|
// because its dial does not enter the system resolver.
|
||||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
chain := m.chain
|
chain := m.chain
|
||||||
maxPriority := m.chainMaxPriority
|
maxPriority := m.chainMaxPriority
|
||||||
|
bypass := m.bypassResolver
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if bypass != nil {
|
||||||
|
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||||
|
}
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||||
}
|
}
|
||||||
@@ -394,9 +554,9 @@ func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxP
|
|||||||
return filtered, nil
|
return filtered, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
// osLookup resolves a single family via the bypass resolver (if configured)
|
||||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
// or net.DefaultResolver using resutil, which disambiguates NODATA from
|
||||||
// returns (nil, nil).
|
// NXDOMAIN and Unmaps v4-mapped-v6. NODATA returns (nil, nil).
|
||||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||||
network := resutil.NetworkForQtype(qtype)
|
network := resutil.NetworkForQtype(qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
@@ -406,7 +566,14 @@ func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string
|
|||||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
m.mutex.RLock()
|
||||||
|
resolver := m.bypassResolver
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
if resolver == nil {
|
||||||
|
resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
result := resutil.LookupIP(ctx, resolver, network, d.PunycodeString(), qtype)
|
||||||
if result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeSuccess {
|
||||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||||
}
|
}
|
||||||
@@ -467,6 +634,24 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPoolRootDomains returns the set of domains that should be registered
|
||||||
|
// with subdomain matching (currently the Relay entries from ServerDomains).
|
||||||
|
// Instance subdomains under these roots are resolved on demand in ServeDNS.
|
||||||
|
func (m *Resolver) GetPoolRootDomains() domain.List {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
if m.serverDomains == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(domain.List, 0, len(m.serverDomains.Relay))
|
||||||
|
for _, d := range m.serverDomains.Relay {
|
||||||
|
if d != "" {
|
||||||
|
out = append(out, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetCachedDomains returns a list of all cached domains.
|
// GetCachedDomains returns a list of all cached domains.
|
||||||
func (m *Resolver) GetCachedDomains() domain.List {
|
func (m *Resolver) GetCachedDomains() domain.List {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
|
|||||||
@@ -31,6 +31,28 @@ import (
|
|||||||
|
|
||||||
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||||
|
|
||||||
|
// subdomainMatchHandler is a thin wrapper used to register a handler under
|
||||||
|
// a pool-root domain (e.g. a relay URL advertised by the mgmt) with
|
||||||
|
// subdomain matching enabled. The underlying handler's own MatchSubdomains
|
||||||
|
// is left untouched so that exact-match registrations keep their
|
||||||
|
// semantics.
|
||||||
|
type subdomainMatchHandler struct {
|
||||||
|
dns.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchSubdomains lets the handler chain route any instance subdomain
|
||||||
|
// (e.g. streamline-de-fra1-0.relay.netbird.io) to the wrapped handler.
|
||||||
|
func (subdomainMatchHandler) MatchSubdomains() bool { return true }
|
||||||
|
|
||||||
|
// String returns a debug-friendly name; the chain uses fmt.Stringer for
|
||||||
|
// its "registering handler X" logs.
|
||||||
|
func (h subdomainMatchHandler) String() string {
|
||||||
|
if s, ok := h.Handler.(fmt.Stringer); ok {
|
||||||
|
return s.String() + "[subdomains]"
|
||||||
|
}
|
||||||
|
return "subdomainMatchHandler"
|
||||||
|
}
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -95,6 +117,11 @@ type DefaultServer struct {
|
|||||||
batchMode bool
|
batchMode bool
|
||||||
|
|
||||||
mgmtCacheResolver *mgmt.Resolver
|
mgmtCacheResolver *mgmt.Resolver
|
||||||
|
// mgmtPoolRoots tracks pool-root domains currently contributed to
|
||||||
|
// extraDomains by the mgmt cache, so the next UpdateServerConfig can
|
||||||
|
// decrement the old set before incrementing the new one without
|
||||||
|
// disturbing unrelated registerHandler callers.
|
||||||
|
mgmtPoolRoots map[domain.Domain]struct{}
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@@ -229,6 +256,7 @@ func newDefaultServer(
|
|||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
hostManager: &noopHostConfigurator{},
|
hostManager: &noopHostConfigurator{},
|
||||||
mgmtCacheResolver: mgmtCacheResolver,
|
mgmtCacheResolver: mgmtCacheResolver,
|
||||||
|
mgmtPoolRoots: make(map[domain.Domain]struct{}),
|
||||||
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -587,25 +615,92 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
|
|||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
if s.mgmtCacheResolver != nil {
|
if s.mgmtCacheResolver == nil {
|
||||||
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
|
return nil
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("update management cache resolver: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(removedDomains) > 0 {
|
|
||||||
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
|
|
||||||
}
|
|
||||||
|
|
||||||
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
|
||||||
if len(newDomains) > 0 {
|
|
||||||
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update management cache resolver: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(removedDomains) > 0 {
|
||||||
|
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
poolRoots := s.mgmtCacheResolver.GetPoolRootDomains()
|
||||||
|
s.registerMgmtCacheHandlers(poolRoots)
|
||||||
|
s.reconcileMgmtPoolRoots(poolRoots)
|
||||||
|
|
||||||
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// registerMgmtCacheHandlers wires the mgmt cache resolver into the handler
|
||||||
|
// chain for the current set of cached domains. Pool-root domains (advertised
|
||||||
|
// by the mgmt as Relay URLs) go through a thin subdomain-matching wrapper so
|
||||||
|
// a query like "streamline-de-fra1-0.relay.netbird.io" routes to the mgmt
|
||||||
|
// cache resolver, which resolves it on demand through the bypass resolver
|
||||||
|
// instead of falling through to the overlay-routed upstream handler.
|
||||||
|
//
|
||||||
|
// Canonicalize with toZone on both sides of the pool-root membership check so
|
||||||
|
// the comparison is independent of each source's canonical form:
|
||||||
|
// GetPoolRootDomains returns what the extractor stored; GetCachedDomains
|
||||||
|
// strips the trailing dot from question names.
|
||||||
|
func (s *DefaultServer) registerMgmtCacheHandlers(poolRoots domain.List) {
|
||||||
|
poolRootSet := make(map[domain.Domain]struct{}, len(poolRoots))
|
||||||
|
for _, d := range poolRoots {
|
||||||
|
poolRootSet[toZone(d)] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(poolRoots) > 0 {
|
||||||
|
s.registerHandler(poolRoots.ToPunycodeList(), subdomainMatchHandler{Handler: s.mgmtCacheResolver}, PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
|
||||||
|
var exactDomains domain.List
|
||||||
|
for _, d := range s.mgmtCacheResolver.GetCachedDomains() {
|
||||||
|
if _, isPool := poolRootSet[toZone(d)]; isPool {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
exactDomains = append(exactDomains, d)
|
||||||
|
}
|
||||||
|
if len(exactDomains) > 0 {
|
||||||
|
s.registerHandler(exactDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconcileMgmtPoolRoots keeps extraDomains in sync with the current mgmt
|
||||||
|
// pool-root set. These entries show up as *match* domains for the host DNS
|
||||||
|
// manager (systemd-resolved, NetworkManager, etc.) so instance subdomain
|
||||||
|
// queries like streamline-* are delegated to the wt0 link where the daemon's
|
||||||
|
// DNS listener sits. Without this, systemd-resolved answers them from the
|
||||||
|
// host's global upstream, skipping our handler chain entirely.
|
||||||
|
//
|
||||||
|
// Uses s.mgmtPoolRoots as a dedicated tracking map so increments/decrements
|
||||||
|
// here don't collide with RegisterHandler's refcounting.
|
||||||
|
func (s *DefaultServer) reconcileMgmtPoolRoots(poolRoots domain.List) {
|
||||||
|
newPoolRoots := make(map[domain.Domain]struct{}, len(poolRoots))
|
||||||
|
for _, d := range poolRoots {
|
||||||
|
zone := toZone(d)
|
||||||
|
newPoolRoots[zone] = struct{}{}
|
||||||
|
if _, already := s.mgmtPoolRoots[zone]; !already {
|
||||||
|
s.extraDomains[zone]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for zone := range s.mgmtPoolRoots {
|
||||||
|
if _, keep := newPoolRoots[zone]; keep {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.extraDomains[zone]--
|
||||||
|
if s.extraDomains[zone] <= 0 {
|
||||||
|
delete(s.extraDomains, zone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mgmtPoolRoots = newPoolRoots
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
if update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
@@ -759,6 +854,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||||
if len(originalNameservers) == 0 {
|
if len(originalNameservers) == 0 {
|
||||||
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||||
|
if s.mgmtCacheResolver != nil {
|
||||||
|
s.mgmtCacheResolver.SetBypassResolver(nil)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -777,6 +875,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
}
|
}
|
||||||
handler.routeMatch = s.routeMatch
|
handler.routeMatch = s.routeMatch
|
||||||
|
|
||||||
|
var bypassNameservers []netip.Addr
|
||||||
for _, ns := range originalNameservers {
|
for _, ns := range originalNameservers {
|
||||||
if ns == config.ServerIP {
|
if ns == config.ServerIP {
|
||||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP)
|
||||||
@@ -785,11 +884,22 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
|
|
||||||
addrPort := netip.AddrPortFrom(ns, DefaultPort)
|
addrPort := netip.AddrPortFrom(ns, DefaultPort)
|
||||||
handler.upstreamServers = append(handler.upstreamServers, addrPort)
|
handler.upstreamServers = append(handler.upstreamServers, addrPort)
|
||||||
|
bypassNameservers = append(bypassNameservers, ns)
|
||||||
}
|
}
|
||||||
handler.deactivate = func(error) { /* always active */ }
|
handler.deactivate = func(error) { /* always active */ }
|
||||||
handler.reactivate = func() { /* always active */ }
|
handler.reactivate = func() { /* always active */ }
|
||||||
|
|
||||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
|
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
|
||||||
|
|
||||||
|
// Wire a bypass resolver into the mgmt cache so its refresh path dials
|
||||||
|
// the original nameservers directly over a fwmarked socket, avoiding
|
||||||
|
// the ENOKEY deadlock that occurs when an exit-node default route is
|
||||||
|
// installed on the overlay before its peer has handshaked. Scoped to
|
||||||
|
// the mgmt cache only: ordinary user DNS still flows through the
|
||||||
|
// normal upstream path.
|
||||||
|
if s.mgmtCacheResolver != nil {
|
||||||
|
s.mgmtCacheResolver.SetBypassResolver(mgmt.NewBypassResolver(bypassNameservers))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
|
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
|
||||||
|
|||||||
Reference in New Issue
Block a user