diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index a67a23945..3229190f2 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -77,8 +77,6 @@ func (d *Resolver) ID() types.HandlerID { return "local-resolver" } -func (d *Resolver) ProbeAvailability(context.Context) {} - // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { logger := log.WithFields(log.Fields{ diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 548b1f54f..31fedd9e5 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -9,6 +9,7 @@ import ( dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string { return make([]string, 0) } -// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface -func (m *MockServer) ProbeAvailability() { -} - func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { if m.UpdateServerConfigFunc != nil { return m.UpdateServerConfigFunc(domains) @@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } -// SetRouteChecker mock implementation of SetRouteChecker from Server interface -func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) { +// SetRouteSources mock implementation of SetRouteSources from Server interface +func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) { // Mock implementation - no-op } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4f54dec5..46e07f98d 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -6,11 +6,9 @@ import ( "fmt" "net/netip" "net/url" - "os" - "runtime" - "strconv" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -25,11 +23,26 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" +const ( + // healthLookback must exceed the upstream query timeout so one + // query per refresh cycle is enough to keep a group marked healthy. + healthLookback = 60 * time.Second + nsGroupHealthRefreshInterval = 10 * time.Second + // defaultWarningDelayBase is the starting grace window before a + // "Nameserver group unreachable" event fires for a group that's + // never been healthy and only has overlay upstreams with no + // Connected peer. Per-server and overridable; see warningDelay. + defaultWarningDelayBase = 30 * time.Second + // warningDelayBonusCap caps the route-count bonus added to the + // base grace window. See warningDelay. 
+ warningDelayBonusCap = 30 * time.Second +) // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { @@ -54,10 +67,9 @@ type Server interface { UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(addrs []netip.AddrPort) SearchDomains() []string - ProbeAvailability() UpdateServerConfig(domains dnsconfig.ServerDomains) error PopulateManagementDomain(mgmtURL *url.URL) error - SetRouteChecker(func(netip.Addr) bool) + SetRouteSources(selected, active func() route.HAMap) SetFirewall(Firewall) } @@ -66,6 +78,47 @@ type nsGroupsByDomain struct { groups []*nbdns.NameServerGroup } +// nsGroupID identifies a nameserver group by the tuple (server list, domain +// list) so config updates produce stable IDs across recomputations. +type nsGroupID string + +// nsHealthSnapshot is the input to projectNSGroupHealth, captured under +// s.mux so projection runs lock-free. +type nsHealthSnapshot struct { + groups []*nbdns.NameServerGroup + merged map[netip.AddrPort]UpstreamHealth + selected route.HAMap + active route.HAMap +} + +// nsGroupProj holds per-group state for the emission rules. +type nsGroupProj struct { + // unhealthySince is the start of the current Unhealthy streak, + // zero when the group is not currently Unhealthy. + unhealthySince time.Time + // everHealthy is sticky: once the group has been Healthy at least + // once this session, subsequent failures skip warningDelay. + everHealthy bool + // warningActive tracks whether we've already published a warning + // for the current streak, so recovery emits iff a warning did. + warningActive bool +} + +// nsGroupVerdict is the outcome of evaluateNSGroupHealth. +type nsGroupVerdict int + +const ( + // nsVerdictUndecided means no upstream has a fresh observation + // (startup before first query, or records aged past healthLookback). + nsVerdictUndecided nsGroupVerdict = iota + // nsVerdictHealthy means at least one upstream's most-recent + // in-lookback observation is a success. + nsVerdictHealthy + // nsVerdictUnhealthy means at least one upstream has a recent + // failure and none has a fresher success. + nsVerdictUnhealthy +) + // hostManagerWithOriginalNS extends the basic hostManager interface type hostManagerWithOriginalNS interface { hostManager @@ -106,20 +159,35 @@ type DefaultServer struct { statusRecorder *peer.Status stateManager *statemanager.Manager - routeMatch func(netip.Addr) bool + // selectedRoutes returns admin-enabled client routes. + selectedRoutes func() route.HAMap + // activeRoutes returns the subset whose peer is in StatusConnected. + activeRoutes func() route.HAMap - probeMu sync.Mutex - probeCancel context.CancelFunc - probeWg sync.WaitGroup + nsGroups []*nbdns.NameServerGroup + healthProjectMu sync.Mutex + // nsGroupProj is the per-group state used by the emission rules. + // Accessed only under healthProjectMu. + nsGroupProj map[nsGroupID]*nsGroupProj + // warningDelayBase is the base grace window for health projection. + // Set at construction, mutated only by tests. Read by the + // refresher goroutine so never change it while one is running. + warningDelayBase time.Duration + // healthRefresh is buffered=1; writers coalesce, senders never block. + // See refreshHealth for the lock-order rationale. 
+ healthRefresh chan struct{} } type handlerWithStop interface { dns.Handler Stop() - ProbeAvailability(context.Context) ID() types.HandlerID } +type upstreamHealthReporter interface { + UpstreamHealth() map[netip.AddrPort]UpstreamHealth +} + type handlerWrapper struct { domain string handler handlerWithStop @@ -230,6 +298,8 @@ func newDefaultServer( hostManager: &noopHostConfigurator{}, mgmtCacheResolver: mgmtCacheResolver, currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied + warningDelayBase: defaultWarningDelayBase, + healthRefresh: make(chan struct{}, 1), } // register with root zone, handler chain takes care of the routing @@ -238,12 +308,13 @@ func newDefaultServer( return defaultServer } -// SetRouteChecker sets the function used by upstream resolvers to determine -// whether an IP is routed through the tunnel. -func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) { +// SetRouteSources wires the route-manager accessors used by health +// projection to classify each upstream for emission timing. +func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) { s.mux.Lock() defer s.mux.Unlock() - s.routeMatch = f + s.selectedRoutes = selected + s.activeRoutes = active } // RegisterHandler registers a handler for the given domains with the given priority. @@ -256,7 +327,6 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler // TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain for _, domain := range domains { - // convert to zone with simple ref counter s.extraDomains[toZone(domain)]++ } if !s.batchMode { @@ -357,6 +427,8 @@ func (s *DefaultServer) Initialize() (err error) { s.stateManager.RegisterState(&ShutdownState{}) + s.startHealthRefresher() + // Keep using noop host manager if dns off requested or running in netstack mode. // Netstack mode currently doesn't have a way to receive DNS requests. // TODO: Use listener on localhost in netstack mode when running as root. @@ -394,13 +466,7 @@ func (s *DefaultServer) SetFirewall(fw Firewall) { // Stop stops the server func (s *DefaultServer) Stop() { - s.probeMu.Lock() - if s.probeCancel != nil { - s.probeCancel() - } s.ctxCancel() - s.probeMu.Unlock() - s.probeWg.Wait() s.shutdownWg.Wait() s.mux.Lock() @@ -411,6 +477,13 @@ func (s *DefaultServer) Stop() { } maps.Clear(s.extraDomains) + + // Clear health projection state so a subsequent Start doesn't + // inherit sticky flags (notably everHealthy) that would bypass + // the grace window during the next peer handshake. + s.healthProjectMu.Lock() + s.nsGroupProj = nil + s.healthProjectMu.Unlock() } func (s *DefaultServer) disableDNS() (retErr error) { @@ -446,7 +519,6 @@ func (s *DefaultServer) disableDNS() (retErr error) { func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []netip.AddrPort) { s.hostsDNSHolder.set(hostsDnsList) - // Check if there's any root handler var hasRootHandler bool for _, handler := range s.dnsMuxMap { if handler.domain == nbdns.RootZone { @@ -520,69 +592,6 @@ func (s *DefaultServer) SearchDomains() []string { return searchDomains } -// ProbeAvailability tests each upstream group's servers for availability -// and deactivates the group if no server responds. -// If a previous probe is still running, it will be cancelled before starting a new one. 
-func (s *DefaultServer) ProbeAvailability() { - if val := os.Getenv(envSkipDNSProbe); val != "" { - skipProbe, err := strconv.ParseBool(val) - if err != nil { - log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err) - } - if skipProbe { - log.Infof("skipping DNS probe due to %s", envSkipDNSProbe) - return - } - } - - s.probeMu.Lock() - - // don't start probes on a stopped server - if s.ctx.Err() != nil { - s.probeMu.Unlock() - return - } - - // cancel any running probe - if s.probeCancel != nil { - s.probeCancel() - s.probeCancel = nil - } - - // wait for the previous probe goroutines to finish while holding - // the mutex so no other caller can start a new probe concurrently - s.probeWg.Wait() - - // start a new probe - probeCtx, probeCancel := context.WithCancel(s.ctx) - s.probeCancel = probeCancel - - s.probeWg.Add(1) - defer s.probeWg.Done() - - // Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers. - s.mux.Lock() - handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap)) - for _, mux := range s.dnsMuxMap { - handlers = append(handlers, mux.handler) - } - s.mux.Unlock() - - var wg sync.WaitGroup - for _, handler := range handlers { - wg.Add(1) - go func(h handlerWithStop) { - defer wg.Done() - h.ProbeAvailability(probeCtx) - }(handler) - } - - s.probeMu.Unlock() - - wg.Wait() - probeCancel() -} - func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { s.mux.Lock() defer s.mux.Unlock() @@ -769,25 +778,23 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { s.wgInterface, s.statusRecorder, s.hostsDNSHolder, - nbdns.RootZone, + domain.Domain(nbdns.RootZone), ) if err != nil { log.Errorf("failed to create upstream resolver for original nameservers: %v", err) return } - handler.routeMatch = s.routeMatch + handler.selectedRoutes = s.selectedRoutes + var servers []netip.AddrPort for _, ns := range originalNameservers { if ns == config.ServerIP { log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, config.ServerIP) continue } - - addrPort := netip.AddrPortFrom(ns, DefaultPort) - handler.upstreamServers = append(handler.upstreamServers, addrPort) + servers = append(servers, netip.AddrPortFrom(ns, DefaultPort)) } - handler.deactivate = func(error) { /* always active */ } - handler.reactivate = func() { /* always active */ } + handler.addRace(servers) s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) } @@ -847,100 +854,77 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityUpstream + priority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { - basePriority = PriorityDefault + priority = PriorityDefault } - updates, err := s.createHandlersForDomainGroup(domainGroup, basePriority) + update, err := s.buildMergedDomainHandler(domainGroup, priority) if err != nil { return nil, err } - muxUpdates = append(muxUpdates, updates...) + if update != nil { + muxUpdates = append(muxUpdates, *update) + } } return muxUpdates, nil } -func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomain, basePriority int) ([]handlerWrapper, error) { - var muxUpdates []handlerWrapper +// buildMergedDomainHandler merges every nameserver group that targets the +// same domain into one handler whose inner groups are raced in parallel. 
+func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, priority int) (*handlerWrapper, error) { + handler, err := newUpstreamResolver( + s.ctx, + s.wgInterface, + s.statusRecorder, + s.hostsDNSHolder, + domain.Domain(domainGroup.domain), + ) + if err != nil { + return nil, fmt.Errorf("create upstream resolver: %v", err) + } + handler.selectedRoutes = s.selectedRoutes - for i, nsGroup := range domainGroup.groups { - // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts - priority := basePriority - i - - // Check if we're about to overlap with the next priority tier - if s.leaksPriority(domainGroup, basePriority, priority) { - break - } - - log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) - handler, err := newUpstreamResolver( - s.ctx, - s.wgInterface, - s.statusRecorder, - s.hostsDNSHolder, - domainGroup.domain, - ) - if err != nil { - return nil, fmt.Errorf("create upstream resolver: %v", err) - } - handler.routeMatch = s.routeMatch - - for _, ns := range nsGroup.NameServers { - if ns.NSType != nbdns.UDPNameServerType { - log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", - ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) - continue - } - - if ns.IP == s.service.RuntimeIP() { - log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) - continue - } - - handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) - } - - if len(handler.upstreamServers) == 0 { - handler.Stop() - log.Errorf("received a nameserver group with an invalid nameserver list") + for _, nsGroup := range domainGroup.groups { + servers := s.filterNameServers(nsGroup.NameServers) + if len(servers) == 0 { + log.Warnf("nameserver group for domain=%s yielded no usable servers, skipping", domainGroup.domain) continue } - - // when upstream fails to resolve domain several times over all it servers - // it will calls this hook to exclude self from the configuration and - // reapply DNS settings, but it not touch the original configuration and serial number - // because it is temporal deactivation until next try - // - // after some period defined by upstream it tries to reactivate self by calling this hook - // everything we need here is just to re-apply current configuration because it already - // contains this upstream settings (temporal deactivation not removed it) - handler.deactivate, handler.reactivate = s.upstreamCallbacks(nsGroup, handler, priority) - - muxUpdates = append(muxUpdates, handlerWrapper{ - domain: domainGroup.domain, - handler: handler, - priority: priority, - }) + handler.addRace(servers) } - return muxUpdates, nil + if len(handler.upstreamServers) == 0 { + handler.Stop() + log.Errorf("no usable nameservers for domain=%s", domainGroup.domain) + return nil, nil + } + + log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority) + + return &handlerWrapper{ + domain: domainGroup.domain, + handler: handler, + priority: priority, + }, nil } -func (s *DefaultServer) leaksPriority(domainGroup nsGroupsByDomain, basePriority int, priority int) bool { - if basePriority == PriorityUpstream && priority <= PriorityDefault { - log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). 
Skipping remaining handlers", - domainGroup.domain, PriorityUpstream-PriorityDefault) - return true +func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []netip.AddrPort { + var out []netip.AddrPort + for _, ns := range nameServers { + if ns.NSType != nbdns.UDPNameServerType { + log.Warnf("skipping nameserver %s with type %s, this peer supports only %s", + ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) + continue + } + if ns.IP == s.service.RuntimeIP() { + log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP) + continue + } + out = append(out, ns.AddrPort()) } - if basePriority == PriorityDefault && priority <= PriorityFallback { - log.Warnf("too many handlers for domain=%s, would overlap with fallback priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityDefault-PriorityFallback) - return true - } - - return false + return out } func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { @@ -974,84 +958,6 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) { s.dnsMuxMap = muxUpdateMap } -// upstreamCallbacks returns two functions, the first one is used to deactivate -// the upstream resolver from the configuration, the second one is used to -// reactivate it. Not allowed to call reactivate before deactivate. -func (s *DefaultServer) upstreamCallbacks( - nsGroup *nbdns.NameServerGroup, - handler dns.Handler, - priority int, -) (deactivate func(error), reactivate func()) { - var removeIndex map[string]int - deactivate = func(err error) { - s.mux.Lock() - defer s.mux.Unlock() - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Info("Temporarily deactivating nameservers group due to timeout") - - removeIndex = make(map[string]int) - for _, domain := range nsGroup.Domains { - removeIndex[domain] = -1 - } - if nsGroup.Primary { - removeIndex[nbdns.RootZone] = -1 - s.currentConfig.RouteAll = false - s.deregisterHandler([]string{nbdns.RootZone}, priority) - } - - for i, item := range s.currentConfig.Domains { - if _, found := removeIndex[item.Domain]; found { - s.currentConfig.Domains[i].Disabled = true - s.deregisterHandler([]string{item.Domain}, priority) - removeIndex[item.Domain] = i - } - } - - // Always apply host config when nameserver goes down, regardless of batch mode - s.applyHostConfig() - - go func() { - if err := s.stateManager.PersistState(s.ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } - }() - - if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { - s.addHostRootZone() - } - - s.updateNSState(nsGroup, err, false) - } - - reactivate = func() { - s.mux.Lock() - defer s.mux.Unlock() - - for domain, i := range removeIndex { - if i == -1 || i >= len(s.currentConfig.Domains) || s.currentConfig.Domains[i].Domain != domain { - continue - } - s.currentConfig.Domains[i].Disabled = false - s.registerHandler([]string{domain}, handler, priority) - } - - l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary disabled nameserver group") - - if nsGroup.Primary { - s.currentConfig.RouteAll = true - s.registerHandler([]string{nbdns.RootZone}, handler, priority) - } - - // Always apply host config when nameserver reactivates, regardless of batch mode - s.applyHostConfig() - - s.updateNSState(nsGroup, nil, true) - } - return -} - func (s *DefaultServer) addHostRootZone() { hostDNSServers := s.hostsDNSHolder.get() if len(hostDNSServers) == 0 { @@ -1070,56 
+976,343 @@ func (s *DefaultServer) addHostRootZone() { log.Errorf("unable to create a new upstream resolver, error: %v", err) return } - handler.routeMatch = s.routeMatch + handler.selectedRoutes = s.selectedRoutes - handler.upstreamServers = maps.Keys(hostDNSServers) - handler.deactivate = func(error) {} - handler.reactivate = func() {} + handler.addRace(maps.Keys(hostDNSServers)) s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } +// updateNSGroupStates records the new group set and pokes the refresher. +// Must hold s.mux; projection runs async (see refreshHealth for why). func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { - var states []peer.NSGroupState + s.nsGroups = groups + select { + case s.healthRefresh <- struct{}{}: + default: + } +} - for _, group := range groups { - var servers []netip.AddrPort - for _, ns := range group.NameServers { - servers = append(servers, ns.AddrPort()) +// refreshHealth runs one projection cycle. Must not be called while +// holding s.mux: the route callbacks re-enter routemanager's lock. +func (s *DefaultServer) refreshHealth() { + s.mux.Lock() + groups := s.nsGroups + merged := s.collectUpstreamHealth() + selFn := s.selectedRoutes + actFn := s.activeRoutes + s.mux.Unlock() + + var selected, active route.HAMap + if selFn != nil { + selected = selFn() + } + if actFn != nil { + active = actFn() + } + + s.projectNSGroupHealth(nsHealthSnapshot{ + groups: groups, + merged: merged, + selected: selected, + active: active, + }) +} + +// projectNSGroupHealth applies the emission rules to the snapshot and +// publishes the resulting NSGroupStates. Serialized by healthProjectMu, +// lock-free wrt s.mux. +// +// Rules: +// - Healthy: emit recovery iff warningActive; set everHealthy. +// - Unhealthy: stamp unhealthySince on streak start; emit warning +// iff any of immediate / everHealthy / elapsed >= effective delay. +// - Undecided: no-op. +// +// "Immediate" means the group has at least one upstream that's public +// or overlay+Connected: no peer-startup race to wait out. +func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) { + if s.statusRecorder == nil { + return + } + + s.healthProjectMu.Lock() + defer s.healthProjectMu.Unlock() + + if s.nsGroupProj == nil { + s.nsGroupProj = make(map[nsGroupID]*nsGroupProj) + } + + now := time.Now() + delay := s.warningDelay(len(snap.selected)) + states := make([]peer.NSGroupState, 0, len(snap.groups)) + seen := make(map[nsGroupID]struct{}, len(snap.groups)) + for _, group := range snap.groups { + servers := nameServerAddrPorts(group.NameServers) + verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now) + id := generateGroupKey(group) + seen[id] = struct{}{} + + immediate := s.groupHasImmediateUpstream(servers, snap) + + p, known := s.nsGroupProj[id] + if !known { + p = &nsGroupProj{} + s.nsGroupProj[id] = p } - state := peer.NSGroupState{ - ID: generateGroupKey(group), + enabled := true + switch verdict { + case nsVerdictHealthy: + enabled = s.projectHealthy(p, servers) + case nsVerdictUnhealthy: + enabled = s.projectUnhealthy(p, servers, immediate, now, delay) + case nsVerdictUndecided: + // Stay Available until evidence says otherwise, unless a + // warning is already active for this group. 
+ enabled = !p.warningActive + groupErr = nil + } + + states = append(states, peer.NSGroupState{ + ID: string(id), Servers: servers, Domains: group.Domains, - // The probe will determine the state, default enabled - Enabled: true, - Error: nil, - } - states = append(states, state) + Enabled: enabled, + Error: groupErr, + }) } - s.statusRecorder.UpdateDNSStates(states) -} - -func (s *DefaultServer) updateNSState(nsGroup *nbdns.NameServerGroup, err error, enabled bool) { - states := s.statusRecorder.GetDNSStates() - id := generateGroupKey(nsGroup) - for i, state := range states { - if state.ID == id { - states[i].Enabled = enabled - states[i].Error = err - break + for id := range s.nsGroupProj { + if _, ok := seen[id]; !ok { + delete(s.nsGroupProj, id) } } s.statusRecorder.UpdateDNSStates(states) } -func generateGroupKey(nsGroup *nbdns.NameServerGroup) string { +// projectHealthy records a healthy tick on p and publishes a recovery +// event iff a warning was active for the current streak. Returns the +// Enabled flag to record in NSGroupState. +func (s *DefaultServer) projectHealthy(p *nsGroupProj, servers []netip.AddrPort) bool { + p.everHealthy = true + p.unhealthySince = time.Time{} + if !p.warningActive { + return true + } + log.Debugf("DNS health: group [%s] recovered, emitting event", joinAddrPorts(servers)) + s.statusRecorder.PublishEvent( + proto.SystemEvent_INFO, + proto.SystemEvent_DNS, + "Nameserver group recovered", + "DNS servers are reachable again.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = false + return true +} + +// projectUnhealthy records an unhealthy tick on p, publishes the +// warning when the emission rules fire, and returns the Enabled flag +// to record in NSGroupState. +func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPort, immediate bool, now time.Time, delay time.Duration) bool { + streakStart := p.unhealthySince.IsZero() + if streakStart { + p.unhealthySince = now + } + reason := unhealthyEmitReason(immediate, p.everHealthy, now.Sub(p.unhealthySince), delay) + switch { + case reason != "" && !p.warningActive: + log.Debugf("DNS health: group [%s] unreachable, emitting event (reason=%s)", joinAddrPorts(servers), reason) + s.statusRecorder.PublishEvent( + proto.SystemEvent_WARNING, + proto.SystemEvent_DNS, + "Nameserver group unreachable", + "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", + map[string]string{"upstreams": joinAddrPorts(servers)}, + ) + p.warningActive = true + case streakStart && reason == "": + // One line per streak, not per tick. + log.Debugf("DNS health: group [%s] unreachable but holding warning for up to %v (overlay-routed, no connected peer)", joinAddrPorts(servers), delay) + } + return false +} + +// warningDelay returns the grace window for the given selected-route +// count. Scales gently: +1s per 100 routes, capped by +// warningDelayBonusCap. Parallel handshakes mean handshake time grows +// much slower than route count, so linear scaling would overcorrect. 
+func (s *DefaultServer) warningDelay(routeCount int) time.Duration { + bonus := time.Duration(routeCount/100) * time.Second + if bonus > warningDelayBonusCap { + bonus = warningDelayBonusCap + } + return s.warningDelayBase + bonus +} + +// groupHasImmediateUpstream reports whether the group has at least one +// upstream in a classification that bypasses the grace window: public +// (outside the overlay range and not routed), or overlay/routed with a +// Connected peer. +// +// TODO(ipv6): include the v6 overlay prefix once it's plumbed in. +func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap nsHealthSnapshot) bool { + var overlayV4 netip.Prefix + if s.wgInterface != nil { + overlayV4 = s.wgInterface.Address().Network + } + for _, srv := range servers { + addr := srv.Addr().Unmap() + overlay := overlayV4.IsValid() && overlayV4.Contains(addr) + routed := haMapContains(snap.selected, addr) + if !overlay && !routed { + return true + } + if haMapContains(snap.active, addr) { + return true + } + } + return false +} + +// collectUpstreamHealth merges health snapshots across handlers, keeping +// the most recent success and failure per upstream when an address appears +// in more than one handler. +func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth { + merged := make(map[netip.AddrPort]UpstreamHealth) + for _, entry := range s.dnsMuxMap { + reporter, ok := entry.handler.(upstreamHealthReporter) + if !ok { + continue + } + for addr, h := range reporter.UpstreamHealth() { + existing, have := merged[addr] + if !have { + merged[addr] = h + continue + } + if h.LastOk.After(existing.LastOk) { + existing.LastOk = h.LastOk + } + if h.LastFail.After(existing.LastFail) { + existing.LastFail = h.LastFail + existing.LastErr = h.LastErr + } + merged[addr] = existing + } + } + return merged +} + +func (s *DefaultServer) startHealthRefresher() { + s.shutdownWg.Add(1) + go func() { + defer s.shutdownWg.Done() + ticker := time.NewTicker(nsGroupHealthRefreshInterval) + defer ticker.Stop() + for { + select { + case <-s.ctx.Done(): + return + case <-ticker.C: + case <-s.healthRefresh: + } + s.refreshHealth() + } + }() +} + +// evaluateNSGroupHealth decides a group's verdict from query records +// alone. Per upstream, the most-recent-in-lookback observation wins. +// Group is Healthy if any upstream is fresh-working, Unhealthy if any +// is fresh-broken with no fresh-working sibling, Undecided otherwise. +func evaluateNSGroupHealth(merged map[netip.AddrPort]UpstreamHealth, servers []netip.AddrPort, now time.Time) (nsGroupVerdict, error) { + anyWorking := false + anyBroken := false + var mostRecentFail time.Time + var mostRecentErr string + + for _, srv := range servers { + h, ok := merged[srv] + if !ok { + continue + } + switch classifyUpstreamHealth(h, now) { + case upstreamFresh: + anyWorking = true + case upstreamBroken: + anyBroken = true + if h.LastFail.After(mostRecentFail) { + mostRecentFail = h.LastFail + mostRecentErr = h.LastErr + } + } + } + + if anyWorking { + return nsVerdictHealthy, nil + } + if anyBroken { + if mostRecentErr == "" { + return nsVerdictUnhealthy, nil + } + return nsVerdictUnhealthy, errors.New(mostRecentErr) + } + return nsVerdictUndecided, nil +} + +// upstreamClassification is the per-upstream verdict within healthLookback. 
+type upstreamClassification int + +const ( + upstreamStale upstreamClassification = iota + upstreamFresh + upstreamBroken +) + +// classifyUpstreamHealth compares the last ok and last fail timestamps +// against healthLookback and returns which one (if any) counts. Fresh +// wins when both are in-window and ok is newer; broken otherwise. +func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassification { + okRecent := !h.LastOk.IsZero() && now.Sub(h.LastOk) <= healthLookback + failRecent := !h.LastFail.IsZero() && now.Sub(h.LastFail) <= healthLookback + switch { + case okRecent && failRecent: + if h.LastOk.After(h.LastFail) { + return upstreamFresh + } + return upstreamBroken + case okRecent: + return upstreamFresh + case failRecent: + return upstreamBroken + } + return upstreamStale +} + +// nameServerAddrPorts flattens a NameServer list to AddrPorts. +func nameServerAddrPorts(ns []nbdns.NameServer) []netip.AddrPort { + out := make([]netip.AddrPort, 0, len(ns)) + for _, n := range ns { + out = append(out, n.AddrPort()) + } + return out +} + +func joinAddrPorts(servers []netip.AddrPort) string { + parts := make([]string, 0, len(servers)) + for _, s := range servers { + parts = append(parts, s.String()) + } + return strings.Join(parts, ", ") +} + +func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID { var servers []string for _, ns := range nsGroup.NameServers { servers = append(servers, ns.AddrPort().String()) } - return fmt.Sprintf("%v_%v", servers, nsGroup.Domains) + return nsGroupID(fmt.Sprintf("%v_%v", servers, nsGroup.Domains)) } // groupNSGroupsByDomain groups nameserver groups by their match domains @@ -1161,6 +1354,21 @@ func toZone(d domain.Domain) domain.Domain { ) } +// unhealthyEmitReason returns the tag of the rule that fires the +// warning now, or "" if the group is still inside its grace window. 
+func unhealthyEmitReason(immediate, everHealthy bool, elapsed, delay time.Duration) string { + switch { + case immediate: + return "immediate" + case everHealthy: + return "ever-healthy" + case elapsed >= delay: + return "grace-elapsed" + default: + return "" + } +} + // PopulateManagementDomain populates the DNS cache with management domain func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error { if s.mgmtCacheResolver != nil { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index f77f6e898..7b596d3fd 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -6,7 +6,6 @@ import ( "net" "net/netip" "os" - "strings" "testing" "time" @@ -15,6 +14,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -31,8 +31,10 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/client/proto" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" + "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) @@ -101,16 +103,17 @@ func init() { formatter.SetTextFormatter(log.StandardLogger()) } -func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase { +func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase { var srvs []netip.AddrPort for _, srv := range servers { srvs = append(srvs, srv.AddrPort()) } - return &upstreamResolverBase{ - domain: domain, - upstreamServers: srvs, - cancel: func() {}, + u := &upstreamResolverBase{ + domain: domain.Domain(d), + cancel: func() {}, } + u.addRace(srvs) + return u } func TestUpdateDNSServer(t *testing.T) { @@ -653,73 +656,6 @@ func TestDNSServerStartStop(t *testing.T) { } } -func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { - hostManager := &mockHostConfigurator{} - server := DefaultServer{ - ctx: context.Background(), - service: NewServiceViaMemory(&mocWGIface{}), - localResolver: local.NewResolver(), - handlerChain: NewHandlerChain(), - hostManager: hostManager, - currentConfig: HostDNSConfig{ - Domains: []DomainConfig{ - {false, "domain0", false}, - {false, "domain1", false}, - {false, "domain2", false}, - }, - }, - statusRecorder: peer.NewRecorder("mgm"), - } - - var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { - domains := []string{} - for _, item := range config.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - domainsUpdate = strings.Join(domains, ",") - return nil - } - - deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{ - Domains: []string{"domain1"}, - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53}, - }, - }, nil, 0) - - deactivate(nil) - expected := "domain0,domain2" - domains := []string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got := strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, got) - } - - reactivate() - expected = 
"domain0,domain1,domain2" - domains = []string{} - for _, item := range server.currentConfig.Domains { - if item.Disabled { - continue - } - domains = append(domains, item.Domain) - } - got = strings.Join(domains, ",") - if expected != got { - t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate) - } -} - func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { wgIFace, err := createWgInterfaceWithBind(t) if err != nil { @@ -2085,6 +2021,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) { assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) } +// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple +// admin-defined nameserver groups targeting the same domain collapse into a +// single handler with each group preserved as a sequential inner list. +func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) { + wgInterface := &mocWGIface{} + service := NewServiceViaMemory(wgInterface) + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgInterface, + service: service, + localResolver: local.NewResolver(), + handlerChain: NewHandlerChain(), + hostManager: &noopHostConfigurator{}, + dnsMuxMap: make(registeredHandlerMap), + } + + groups := []*nbdns.NameServerGroup{ + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + { + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53}, + {IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53}, + }, + Domains: []string{"example.com"}, + }, + } + + muxUpdates, err := server.buildUpstreamHandlerUpdate(groups) + require.NoError(t, err) + require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler") + assert.Equal(t, "example.com", muxUpdates[0].domain) + assert.Equal(t, PriorityUpstream, muxUpdates[0].priority) + + handler := muxUpdates[0].handler.(*upstreamResolver) + require.Len(t, handler.upstreamServers, 2, "handler should have two groups") + assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0]) + assert.Equal(t, upstreamRace{ + netip.MustParseAddrPort("192.0.2.2:53"), + netip.MustParseAddrPort("192.0.2.3:53"), + }, handler.upstreamServers[1]) +} + +// TestEvaluateNSGroupHealth covers the records-only verdict. The gate +// (overlay route selected-but-no-active-peer) is intentionally NOT an +// input to the evaluator anymore: the verdict drives the Enabled flag, +// which must always reflect what we actually observed. Gate-aware event +// suppression is tested separately in the projection test. +// +// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail, +// stale Ok, Ok newer than Fail, Fail newer than Ok}. +// Group verdict: any fresh-working → Healthy; any fresh-broken with no +// fresh-working → Unhealthy; otherwise Undecided. 
+func TestEvaluateNSGroupHealth(t *testing.T) { + now := time.Now() + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)} + recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"} + staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)} + staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"} + okThenFail := UpstreamHealth{ + LastOk: now.Add(-10 * time.Second), + LastFail: now.Add(-1 * time.Second), + LastErr: "timeout", + } + failThenOk := UpstreamHealth{ + LastOk: now.Add(-1 * time.Second), + LastFail: now.Add(-10 * time.Second), + LastErr: "timeout", + } + + tests := []struct { + name string + health map[netip.AddrPort]UpstreamHealth + servers []netip.AddrPort + wantVerdict nsGroupVerdict + wantErrSubst string + }{ + { + name: "no record, undecided", + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "fresh success, healthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "fresh failure, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: recentFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "only stale success, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "only stale failure, undecided", + health: map[netip.AddrPort]UpstreamHealth{a: staleFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUndecided, + }, + { + name: "both fresh, fail newer, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{a: okThenFail}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "both fresh, ok newer, healthy", + health: map[netip.AddrPort]UpstreamHealth{a: failThenOk}, + servers: []netip.AddrPort{a}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one success wins", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + b: recentOk, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictHealthy, + }, + { + name: "two upstreams, one fail one unseen, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: recentFail, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "timeout", + }, + { + name: "two upstreams, all recent failures, unhealthy", + health: map[netip.AddrPort]UpstreamHealth{ + a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"}, + b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"}, + }, + servers: []netip.AddrPort{a, b}, + wantVerdict: nsVerdictUnhealthy, + wantErrSubst: "SERVFAIL", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now) + assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch") + if tc.wantErrSubst != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErrSubst) + } else { + assert.NoError(t, err) + } + }) + } +} + +// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed +// UpstreamHealth snapshot, letting tests drive refreshHealth +// without spinning up real handlers. 
+type healthStubHandler struct { + health map[netip.AddrPort]UpstreamHealth +} + +func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {} +func (h *healthStubHandler) Stop() {} +func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" } +func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + return h.health +} + +// TestProjection_SteadyStateIsSilent guards against duplicate events: +// while a group stays Unhealthy tick after tick, only the first +// Unhealthy transition may emit. Same for staying Healthy. +func TestProjection_SteadyStateIsSilent(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "first fail emits warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.tick() + fx.expectNoEvent("staying unhealthy must not re-emit") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery on transition") + + fx.tick() + fx.tick() + fx.expectNoEvent("staying healthy must not re-emit") +} + +// projTestFixture is the common setup for the projection tests: a +// single-upstream group whose route classification the test can flip by +// assigning to selected/active. Callers drive failures/successes by +// mutating stub.health and calling refreshHealth. +type projTestFixture struct { + t *testing.T + recorder *peer.Status + events <-chan *proto.SystemEvent + server *DefaultServer + stub *healthStubHandler + group *nbdns.NameServerGroup + srv netip.AddrPort + selected route.HAMap + active route.HAMap +} + +func newProjTestFixture(t *testing.T) *projTestFixture { + t.Helper() + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + srv := netip.MustParseAddrPort("100.64.0.1:53") + fx := &projTestFixture{ + t: t, + recorder: recorder, + events: sub.Events(), + stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}}, + srv: srv, + group: &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + }, + } + fx.server = &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return fx.selected }, + activeRoutes: func() route.HAMap { return fx.active }, + warningDelayBase: defaultWarningDelayBase, + } + fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream} + + fx.server.mux.Lock() + fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group}) + fx.server.mux.Unlock() + return fx +} + +func (f *projTestFixture) setHealth(h UpstreamHealth) { + f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h} +} + +func (f *projTestFixture) tick() []peer.NSGroupState { + f.server.refreshHealth() + return f.recorder.GetDNSStates() +} + +func (f *projTestFixture) expectNoEvent(why string) { + f.t.Helper() + select { + case evt := <-f.events: + f.t.Fatalf("unexpected event (%s): %+v", why, evt) + case <-time.After(100 * time.Millisecond): + } +} + +func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent { + f.t.Helper() + select { + case evt := <-f.events: + assert.Contains(f.t, evt.Message, substr, why) + return evt + case 
<-time.After(time.Second): + f.t.Fatalf("expected event (%s) with %q", why, substr) + return nil + } +} + +var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16") +var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}} + +// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream +// that is not inside any selected route (public DNS) fires the warning +// on the first Unhealthy tick, no grace period. +func TestProjection_PublicFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "public DNS failure") +} + +// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2: +// the upstream is inside a selected route AND the route has a Connected +// peer. Tunnel is up, failure is real, emit immediately. +func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled) + fx.expectEvent("unreachable", "overlay + connected failure") +} + +// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the +// upstream is routed but no peer is Connected (Connecting/Idle/missing). +// First tick: Unhealthy display, no warning. After the grace window +// elapses with no recovery, the warning fires. +func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) { + grace := 50 * time.Millisecond + fx := newProjTestFixture(t) + fx.server.warningDelayBase = grace + fx.selected = overlayMapForTest + // active stays nil: routed but not connected. + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + states := fx.tick() + require.Len(t, states, 1) + assert.False(t, states[0].Enabled, "display must reflect failure even during grace window") + fx.expectNoEvent("first fail tick within grace window") + + time.Sleep(grace + 10*time.Millisecond) + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "warning after grace window") +} + +// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream +// whose address is inside the WireGuard overlay range but is not +// covered by any selected route (peer-to-peer DNS without an explicit +// route). Until a peer reports Connected for that address, startup +// failures must be held just like the routed case. 
+func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + + overlayPeer := netip.MustParseAddrPort("100.66.100.5:53") + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: 50 * time.Millisecond, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{ + overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}, + }} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-sub.Events(): + t.Fatalf("unexpected event during grace window: %+v", evt) + case <-time.After(100 * time.Millisecond): + } + + time.Sleep(60 * time.Millisecond) + stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}} + server.refreshHealth() + + select { + case evt := <-sub.Events(): + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected warning after grace window") + } +} + +// TestProjection_StopClearsHealthState verifies that Stop wipes the +// per-group projection state so a subsequent Start doesn't inherit +// sticky flags (notably everHealthy) that would bypass the grace +// window during the next peer handshake. 
+func TestProjection_StopClearsHealthState(t *testing.T) { + wgIface := &mocWGIface{} + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: wgIface, + service: NewServiceViaMemory(wgIface), + hostManager: &noopHostConfigurator{}, + extraDomains: map[domain.Domain]int{}, + dnsMuxMap: make(registeredHandlerMap), + statusRecorder: peer.NewRecorder("mgm"), + selectedRoutes: func() route.HAMap { return nil }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: defaultWarningDelayBase, + currentConfigHash: ^uint64(0), + } + server.ctx, server.ctxCancel = context.WithCancel(context.Background()) + + srv := netip.MustParseAddrPort("8.8.8.8:53") + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, + } + stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}} + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + server.healthProjectMu.Lock() + p, ok := server.nsGroupProj[generateGroupKey(group)] + server.healthProjectMu.Unlock() + require.True(t, ok, "projection state should exist after tick") + require.True(t, p.everHealthy, "tick with success must set everHealthy") + + server.Stop() + + server.healthProjectMu.Lock() + cleared := server.nsGroupProj == nil + server.healthProjectMu.Unlock() + assert.True(t, cleared, "Stop must clear nsGroupProj") +} + +// TestProjection_OverlayRecoversDuringGrace covers the happy path of +// rule 3: startup failures while the peer is handshaking, then the peer +// comes up and a query succeeds before the grace window elapses. No +// warning should ever have fired, and no recovery either. +func TestProjection_OverlayRecoversDuringGrace(t *testing.T) { + fx := newProjTestFixture(t) + fx.server.warningDelayBase = 200 * time.Millisecond + fx.selected = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectNoEvent("fail within grace, warning suppressed") + + fx.active = overlayMapForTest + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("recovery without prior warning must not emit") +} + +// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the +// whole design leans on: recovery events only appear when a warning +// event was actually emitted for the current streak. A Healthy verdict +// without a prior warning is silent, so the user never sees "recovered" +// out of thin air. 
+func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) { + fx := newProjTestFixture(t) + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + states := fx.tick() + require.Len(t, states, 1) + assert.True(t, states[0].Enabled) + fx.expectNoEvent("first healthy tick should not recover anything") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "public fail emits immediately") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "recovery follows real warning") + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "second cycle warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "second cycle recovery") +} + +// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group +// has ever been Healthy, subsequent failures skip the grace window even +// if classification says "routed + not connected". The system has +// proved it can work, so any new failure is real. +func TestProjection_EverHealthyOverridesDelay(t *testing.T) { + fx := newProjTestFixture(t) + // Large base so any emission must come from the everHealthy bypass, not elapsed time. + fx.server.warningDelayBase = time.Hour + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + // Establish "ever healthy". + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectNoEvent("first healthy tick") + + // Peer drops. Query fails. Routed + not connected → normally grace, + // but everHealthy flag bypasses it. + fx.active = nil + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "failure after ever-healthy must be immediate") +} + +// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff +// from the design discussion: once a group has been healthy, a brief +// reconnect that produces a failing tick will fire warning + recovery. +// This is by design: user-visible blips are accurate signal, not noise. +func TestProjection_ReconnectBlipEmitsPair(t *testing.T) { + fx := newProjTestFixture(t) + fx.selected = overlayMapForTest + fx.active = overlayMapForTest + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + + fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"}) + fx.tick() + fx.expectEvent("unreachable", "blip warning") + + fx.setHealth(UpstreamHealth{LastOk: time.Now()}) + fx.tick() + fx.expectEvent("recovered", "blip recovery") +} + +// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream +// rule: a group with at least one public upstream is in the "immediate" +// category regardless of the other upstreams' routing, because the +// public one has no peer-startup excuse. Prevents public-DNS failures +// from being hidden behind a routed sibling. 
+func TestProjection_MixedGroupEmitsImmediately(t *testing.T) { + recorder := peer.NewRecorder("mgm") + sub := recorder.SubscribeToEvents() + t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) }) + events := sub.Events() + + public := netip.MustParseAddrPort("8.8.8.8:53") + overlay := netip.MustParseAddrPort("100.64.0.1:53") + overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}} + + server := &DefaultServer{ + ctx: context.Background(), + statusRecorder: recorder, + dnsMuxMap: make(registeredHandlerMap), + selectedRoutes: func() route.HAMap { return overlayMap }, + activeRoutes: func() route.HAMap { return nil }, + warningDelayBase: time.Hour, + } + group := &nbdns.NameServerGroup{ + Domains: []string{"example.com"}, + NameServers: []nbdns.NameServer{ + {IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())}, + {IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())}, + }, + } + stub := &healthStubHandler{ + health: map[netip.AddrPort]UpstreamHealth{ + public: {LastFail: time.Now(), LastErr: "servfail"}, + overlay: {LastFail: time.Now(), LastErr: "timeout"}, + }, + } + server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} + + server.mux.Lock() + server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) + server.mux.Unlock() + server.refreshHealth() + + select { + case evt := <-events: + assert.Contains(t, evt.Message, "unreachable") + case <-time.After(time.Second): + t.Fatal("expected immediate warning because group contains a public upstream") + } +} + func TestDNSLoopPrevention(t *testing.T) { wgInterface := &mocWGIface{} service := NewServiceViaMemory(wgInterface) @@ -2183,17 +2711,18 @@ func TestDNSLoopPrevention(t *testing.T) { if tt.expectedHandlers > 0 { handler := muxUpdates[0].handler.(*upstreamResolver) - assert.Len(t, handler.upstreamServers, len(tt.expectedServers)) + flat := handler.flatUpstreams() + assert.Len(t, flat, len(tt.expectedServers)) if tt.shouldFilterOwnIP { - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { assert.NotEqual(t, dnsServerIP, upstream.Addr()) } } for _, expected := range tt.expectedServers { found := false - for _, upstream := range handler.upstreamServers { + for _, upstream := range flat { if upstream.Addr() == expected { found = true break diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 746b73ca7..3df69517a 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -1,3 +1,32 @@ +// Package dns implements the client-side DNS stack: listener/service on the +// peer's tunnel address, handler chain that routes questions by domain and +// priority, and upstream resolvers that forward what remains to configured +// nameservers. +// +// # Upstream resolution and the race model +// +// When two or more nameserver groups target the same domain, DefaultServer +// merges them into one upstream handler whose state is: +// +// upstreamResolverBase +// └── upstreamServers []upstreamRace // one entry per source NS group +// └── []netip.AddrPort // primary, fallback, ... +// +// Each source nameserver group contributes one upstreamRace. Within a race +// upstreams are tried in order: the next is used only on failure (timeout, +// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops +// the walk. 
When more than one race exists, ServeDNS fans out one +// goroutine per race and returns the first valid answer, cancelling the +// rest. A handler with a single race skips the fan-out. +// +// # Health projection +// +// Query outcomes are recorded per-upstream in UpstreamHealth. The server +// periodically merges these snapshots across handlers and projects them +// into peer.NSGroupState. There is no active probing: a group is marked +// unhealthy only when every seen upstream has a recent failure and none +// has a recent success. Healthy→unhealthy fires a single +// SystemEvent_WARNING; steady-state refreshes do not duplicate it. package dns import ( @@ -11,11 +40,8 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" - "github.com/cenkalti/backoff/v4" - "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/tun/netstack" @@ -24,7 +50,8 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) var currentMTU uint16 = iface.DefaultMTU @@ -39,15 +66,17 @@ const ( // Set longer than UpstreamTimeout to ensure context timeout takes precedence ClientTimeout = 5 * time.Second - reactivatePeriod = 30 * time.Second - probeTimeout = 2 * time.Second - // ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP // payload from the tunnel MTU. ipUDPHeaderSize = 60 + 8 -) -const testRecord = "com." + // raceMaxTotalTimeout caps the combined time spent walking all upstreams + // within one race, so a slow primary can't eat the whole race budget. + raceMaxTotalTimeout = 5 * time.Second + // raceMinPerUpstreamTimeout is the floor applied when dividing + // raceMaxTotalTimeout across upstreams within a race. + raceMinPerUpstreamTimeout = 2 * time.Second +) const ( protoUDP = "udp" @@ -56,6 +85,68 @@ const ( type dnsProtocolKey struct{} +type upstreamProtocolKey struct{} + +// upstreamProtocolResult holds the protocol used for the upstream exchange. +// Stored as a pointer in context so the exchange function can set it. +type upstreamProtocolResult struct { + protocol string +} + +type upstreamClient interface { + exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +type UpstreamResolver interface { + serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) + upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) +} + +// upstreamRace is an ordered list of upstreams derived from one configured +// nameserver group. Order matters: the first upstream is tried first, the +// second only on failure, and so on. Multiple upstreamRace values coexist +// inside one resolver when overlapping nameserver groups target the same +// domain; those races run in parallel and the first valid answer wins. +type upstreamRace []netip.AddrPort + +// UpstreamHealth is the last query-path outcome for a single upstream, +// consumed by nameserver-group status projection. 
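Reviewer note: a sketch of how a consumer might read one of these records, where "recent" means within the projection's lookback window; recentlyHealthy and the lookback parameter are illustrative names, not part of this diff:

    // recentlyHealthy reports whether the newest in-window observation for
    // this upstream is a success. markUpstreamOk clears LastFail, so a fresh
    // success always outranks an older failure.
    func recentlyHealthy(h UpstreamHealth, lookback time.Duration, now time.Time) bool {
        okFresh := !h.LastOk.IsZero() && now.Sub(h.LastOk) <= lookback
        failFresh := !h.LastFail.IsZero() && now.Sub(h.LastFail) <= lookback
        return okFresh && (!failFresh || h.LastOk.After(h.LastFail))
    }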
+type UpstreamHealth struct { + LastOk time.Time + LastFail time.Time + LastErr string +} + +type upstreamResolverBase struct { + ctx context.Context + cancel context.CancelFunc + upstreamClient upstreamClient + upstreamServers []upstreamRace + domain domain.Domain + upstreamTimeout time.Duration + + healthMu sync.RWMutex + health map[netip.AddrPort]*UpstreamHealth + + statusRecorder *peer.Status + // selectedRoutes returns the current set of client routes the admin + // has enabled. Called lazily from the query hot path when an upstream + // might need a tunnel-bound client (iOS) and from health projection. + selectedRoutes func() route.HAMap +} + +type upstreamFailure struct { + upstream netip.AddrPort + reason string +} + +type raceResult struct { + msg *dns.Msg + upstream netip.AddrPort + protocol string + failures []upstreamFailure +} + // contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context. func contextWithDNSProtocol(ctx context.Context, network string) context.Context { return context.WithValue(ctx, dnsProtocolKey{}, network) @@ -72,14 +163,6 @@ func dnsProtocolFromContext(ctx context.Context) string { return "" } -type upstreamProtocolKey struct{} - -// upstreamProtocolResult holds the protocol used for the upstream exchange. -// Stored as a pointer in context so the exchange function can set it. -type upstreamProtocolResult struct { - protocol string -} - // contextWithupstreamProtocolResult stores a mutable result holder in the context. func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) { r := &upstreamProtocolResult{} @@ -96,64 +179,30 @@ func setUpstreamProtocol(ctx context.Context, protocol string) { } } -type upstreamClient interface { - exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type UpstreamResolver interface { - serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error) - upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) -} - -type upstreamResolverBase struct { - ctx context.Context - cancel context.CancelFunc - upstreamClient upstreamClient - upstreamServers []netip.AddrPort - domain string - disabled bool - successCount atomic.Int32 - mutex sync.Mutex - reactivatePeriod time.Duration - upstreamTimeout time.Duration - wg sync.WaitGroup - - deactivate func(error) - reactivate func() - statusRecorder *peer.Status - routeMatch func(netip.Addr) bool -} - -type upstreamFailure struct { - upstream netip.AddrPort - reason string -} - -func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { +func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) return &upstreamResolverBase{ - ctx: ctx, - cancel: cancel, - domain: domain, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: reactivatePeriod, - statusRecorder: statusRecorder, + ctx: ctx, + cancel: cancel, + domain: d, + upstreamTimeout: UpstreamTimeout, + statusRecorder: statusRecorder, } } // String returns a string representation of the upstream resolver func (u *upstreamResolverBase) String() string { - return fmt.Sprintf("Upstream %s", u.upstreamServers) + return fmt.Sprintf("Upstream %s", u.flatUpstreams()) } // ID returns the unique handler ID func (u *upstreamResolverBase) ID() types.HandlerID { - servers := slices.Clone(u.upstreamServers) + servers := u.flatUpstreams() slices.SortFunc(servers, 
func(a, b netip.AddrPort) int { return a.Compare(b) }) hash := sha256.New() - hash.Write([]byte(u.domain + ":")) + hash.Write([]byte(u.domain.PunycodeString() + ":")) for _, s := range servers { hash.Write([]byte(s.String())) hash.Write([]byte("|")) @@ -166,13 +215,33 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { } func (u *upstreamResolverBase) Stop() { - log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) + log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams()) u.cancel() +} - u.mutex.Lock() - u.wg.Wait() - u.mutex.Unlock() +// flatUpstreams is for logging and ID hashing only, not for dispatch. +func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort { + var out []netip.AddrPort + for _, g := range u.upstreamServers { + out = append(out, g...) + } + return out +} +// isRouted reports whether ip falls inside any client route the admin +// has selected. +func (u *upstreamResolverBase) isRouted(ip netip.Addr) bool { + if u.selectedRoutes == nil { + return false + } + return haMapContains(u.selectedRoutes(), ip) +} + +func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) { + if len(servers) == 0 { + return + } + u.upstreamServers = append(u.upstreamServers, slices.Clone(servers)) } // ServeDNS handles a DNS request @@ -214,59 +283,152 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { - timeout := u.upstreamTimeout - if len(u.upstreamServers) > 1 { - maxTotal := 5 * time.Second - minPerUpstream := 2 * time.Second - scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers)) - if scaledTimeout > minPerUpstream { - timeout = scaledTimeout - } else { - timeout = minPerUpstream - } + groups := u.upstreamServers + switch len(groups) { + case 0: + return false, nil + case 1: + return u.tryOnlyRace(ctx, w, r, groups[0], logger) + default: + return u.raceAll(ctx, w, r, groups, logger) + } +} + +func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + res := u.tryRace(ctx, r, group) + if res.msg == nil { + return false, res.failures + } + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, res.failures +} + +// raceAll runs one worker per group in parallel, taking the first valid +// answer and cancelling the rest. +func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) { + raceCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Buffer sized to len(groups) so workers never block on send, even + // after the coordinator has returned. + results := make(chan raceResult, len(groups)) + for _, g := range groups { + go func(g upstreamRace) { + results <- u.tryRace(raceCtx, r, g) + }(g) } var failures []upstreamFailure - for _, upstream := range u.upstreamServers { - if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil { - failures = append(failures, *failure) - } else { - return true, failures + for range groups { + select { + case res := <-results: + failures = append(failures, res.failures...) 
+ if res.msg != nil { + u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger) + return true, failures + } + case <-ctx.Done(): + return false, failures } } return false, failures } -// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. -func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { - var rm *dns.Msg - var t time.Duration - var err error +func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult { + timeout := u.upstreamTimeout + if len(group) > 1 { + timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout) + } - var startTime time.Time - var upstreamProto *upstreamProtocolResult - func() { - ctx, cancel := context.WithTimeout(parentCtx, timeout) - defer cancel() - ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) - startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) - }() + var failures []upstreamFailure + for _, upstream := range group { + if ctx.Err() != nil { + return raceResult{failures: failures} + } + msg, proto, failure := u.queryUpstream(ctx, r, upstream, timeout) + if failure != nil { + failures = append(failures, *failure) + continue + } + return raceResult{msg: msg, upstream: upstream, protocol: proto, failures: failures} + } + return raceResult{failures: failures} +} + +func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (*dns.Msg, string, *upstreamFailure) { + ctx, cancel := context.WithTimeout(parentCtx, timeout) + defer cancel() + ctx, upstreamProto := contextWithupstreamProtocolResult(ctx) + + startTime := time.Now() + rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r) if err != nil { - return u.handleUpstreamError(err, upstream, startTime) + failure := u.handleUpstreamError(err, upstream, startTime) + u.markUpstreamFail(upstream, failure.reason) + return nil, "", failure } if rm == nil || !rm.Response { - return &upstreamFailure{upstream: upstream, reason: "no response"} + u.markUpstreamFail(upstream, "no response") + return nil, "", &upstreamFailure{upstream: upstream, reason: "no response"} } if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { - return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} + reason := dns.RcodeToString[rm.Rcode] + u.markUpstreamFail(upstream, reason) + return nil, "", &upstreamFailure{upstream: upstream, reason: reason} } - u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) - return nil + u.markUpstreamOk(upstream) + + proto := "" + if upstreamProto != nil { + proto = upstreamProto.protocol + } + return rm, proto, nil +} + +// healthEntry returns the mutable health record for addr, lazily creating +// the map and the entry. Caller must hold u.healthMu. 
+func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth { + if u.health == nil { + u.health = make(map[netip.AddrPort]*UpstreamHealth) + } + h := u.health[addr] + if h == nil { + h = &UpstreamHealth{} + u.health[addr] = h + } + return h +} + +func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastOk = time.Now() + h.LastFail = time.Time{} + h.LastErr = "" +} + +func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) { + u.healthMu.Lock() + defer u.healthMu.Unlock() + h := u.healthEntry(addr) + h.LastFail = time.Now() + h.LastErr = reason +} + +// UpstreamHealth returns a snapshot of per-upstream query outcomes. +func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth { + u.healthMu.RLock() + defer u.healthMu.RUnlock() + out := make(map[netip.AddrPort]UpstreamHealth, len(u.health)) + for k, v := range u.health { + out[k] = *v + } + return out } func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { @@ -282,12 +444,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add return &upstreamFailure{upstream: upstream, reason: reason} } -func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool { - u.successCount.Add(1) +func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { + if u.statusRecorder == nil { + return "" + } + peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) + if peerInfo == nil { + return "" + } + + return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) +} + +func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) { resutil.SetMeta(w, "upstream", upstream.String()) - if upstreamProto != nil && upstreamProto.protocol != "" { - resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol) + if proto != "" { + resutil.SetMeta(w, "upstream_protocol", proto) } // Clear Zero bit from external responses to prevent upstream servers from @@ -296,14 +469,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn if err := w.WriteMsg(rm); err != nil { logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) - return true } - - return true } func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) { - totalUpstreams := len(u.upstreamServers) + totalUpstreams := len(u.flatUpstreams()) failedCount := len(failures) failureSummary := formatFailures(failures) @@ -330,119 +500,6 @@ func formatFailures(failures []upstreamFailure) string { return strings.Join(parts, ", ") } -// ProbeAvailability tests all upstream servers simultaneously and -// disables the resolver if none work -func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) { - u.mutex.Lock() - defer u.mutex.Unlock() - - // avoid probe if upstreams could resolve at least one query - if u.successCount.Load() > 0 { - return - } - - var success bool - var mu sync.Mutex - var wg sync.WaitGroup - - var errs *multierror.Error - for _, upstream := range u.upstreamServers { - wg.Add(1) - go func(upstream 
netip.AddrPort) { - defer wg.Done() - err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond) - if err != nil { - mu.Lock() - errs = multierror.Append(errs, err) - mu.Unlock() - log.Warnf("probing upstream nameserver %s: %s", upstream, err) - return - } - - mu.Lock() - success = true - mu.Unlock() - }(upstream) - } - - wg.Wait() - - select { - case <-ctx.Done(): - return - case <-u.ctx.Done(): - return - default: - } - - // didn't find a working upstream server, let's disable and try later - if !success { - u.disable(errs.ErrorOrNil()) - - if u.statusRecorder == nil { - return - } - - u.statusRecorder.PublishEvent( - proto.SystemEvent_WARNING, - proto.SystemEvent_DNS, - "All upstream servers failed (probe failed)", - "Unable to reach one or more DNS servers. This might affect your ability to connect to some services.", - map[string]string{"upstreams": u.upstreamServersString()}, - ) - } -} - -// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response -func (u *upstreamResolverBase) waitUntilResponse() { - exponentialBackOff := &backoff.ExponentialBackOff{ - InitialInterval: 500 * time.Millisecond, - RandomizationFactor: 0.5, - Multiplier: 1.1, - MaxInterval: u.reactivatePeriod, - MaxElapsedTime: 0, - Stop: backoff.Stop, - Clock: backoff.SystemClock, - } - - operation := func() error { - select { - case <-u.ctx.Done(): - return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString())) - default: - } - - for _, upstream := range u.upstreamServers { - if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { - log.Tracef("upstream check for %s: %s", upstream, err) - } else { - // at least one upstream server is available, stop probing - return nil - } - } - - log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff()) - return fmt.Errorf("upstream check call error") - } - - err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx)) - if err != nil { - if errors.Is(err, context.Canceled) { - log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString()) - } else { - log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err) - } - return - } - - log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) - u.successCount.Add(1) - u.reactivate() - u.mutex.Lock() - u.disabled = false - u.mutex.Unlock() -} - // isTimeout returns true if the given error is a network timeout error. 
// // Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout @@ -454,45 +511,6 @@ func isTimeout(err error) bool { return false } -func (u *upstreamResolverBase) disable(err error) { - if u.disabled { - return - } - - log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod) - u.successCount.Store(0) - u.deactivate(err) - u.disabled = true - u.wg.Add(1) - go func() { - defer u.wg.Done() - u.waitUntilResponse() - }() -} - -func (u *upstreamResolverBase) upstreamServersString() string { - var servers []string - for _, server := range u.upstreamServers { - servers = append(servers, server.String()) - } - return strings.Join(servers, ", ") -} - -func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { - mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) - defer cancel() - - if externalCtx != nil { - stop2 := context.AfterFunc(externalCtx, cancel) - defer stop2() - } - - r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - - _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) - return err -} - // clientUDPMaxSize returns the maximum UDP response size the client accepts. func clientUDPMaxSize(r *dns.Msg) int { if opt := r.IsEdns0(); opt != nil { @@ -718,15 +736,22 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State { return bestMatch } -func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string { - if u.statusRecorder == nil { - return "" +// haMapContains reports whether any route in the map contains ip. +// +// Gap: dynamic (domain-based) routes carry a placeholder Network that +// never matches a real address, so an upstream reached via a dynamic +// route is classified as "not routed" here. The DNS health path then +// emits failure events immediately for such upstreams instead of +// applying the startup grace window. Rare (DNS servers are usually +// designated by IP, not by domain) but worth revisiting if DoT/DoH-style +// upstreams or /etc/hosts-style domain routing to DNS become supported. 
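Reviewer note: for reference, the shape of the check the function below performs; the prefix and addresses are illustrative only:

    overlay := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
    _ = haMapContains(overlay, netip.MustParseAddr("100.64.0.10")) // true: inside the routed prefix
    _ = haMapContains(overlay, netip.MustParseAddr("9.9.9.9"))     // false: classified as a public upstream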
+func haMapContains(hm route.HAMap, ip netip.Addr) bool { + for _, routes := range hm { + for _, r := range routes { + if r.Network.Contains(ip) { + return true + } + } } - - peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder) - if peerInfo == nil { - return "" - } - - return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo)) + return false } diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index ee1ca42fe..0072b1134 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -11,6 +11,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -26,9 +27,9 @@ func newUpstreamResolver( _ WGIface, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) c := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, hostsDNSHolder: hostsDNSHolder, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 1143b6c51..5f8f369a4 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -12,6 +12,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolver struct { @@ -24,9 +25,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolver, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) nonIOS := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, nsNet: wgIface.GetNet(), diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 02c11173b..450152b2e 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/shared/management/domain" ) type upstreamResolverIOS struct { @@ -29,9 +30,9 @@ func newUpstreamResolver( wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, - domain string, + d domain.Domain, ) (*upstreamResolverIOS, error) { - upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) + upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d) ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, @@ -65,8 +66,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r * } else { upstreamIP = upstreamIP.Unmap() } - needsPrivate := u.lNet.Contains(upstreamIP) || - (u.routeMatch != nil && u.routeMatch(upstreamIP)) + needsPrivate := u.lNet.Contains(upstreamIP) || u.isRouted(upstreamIP) if needsPrivate { log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 1797fdad8..b0c510b18 100644 --- 
a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -73,7 +73,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())) } } - resolver.upstreamServers = servers + resolver.addRace(servers) resolver.upstreamTimeout = testCase.timeout if testCase.cancelCTX { cancel() @@ -160,58 +160,6 @@ func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream stri return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) } -func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { - mockClient := &mockUpstreamResolver{ - err: dns.ErrTime, - r: new(dns.Msg), - rtt: time.Millisecond, - } - - resolver := &upstreamResolverBase{ - ctx: context.TODO(), - upstreamClient: mockClient, - upstreamTimeout: UpstreamTimeout, - reactivatePeriod: time.Microsecond * 100, - } - addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection - resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())} - - failed := false - resolver.deactivate = func(error) { - failed = true - // After deactivation, make the mock client work again - mockClient.err = nil - } - - reactivated := false - resolver.reactivate = func() { - reactivated = true - } - - resolver.ProbeAvailability(context.TODO()) - - if !failed { - t.Errorf("expected that resolving was deactivated") - return - } - - if !resolver.disabled { - t.Errorf("resolver should be Disabled") - return - } - - time.Sleep(time.Millisecond * 200) - - if !reactivated { - t.Errorf("expected that resolving was reactivated") - return - } - - if resolver.disabled { - t.Errorf("should be enabled") - } -} - func TestUpstreamResolver_Failover(t *testing.T) { upstream1 := netip.MustParseAddrPort("192.0.2.1:53") upstream2 := netip.MustParseAddrPort("192.0.2.2:53") @@ -339,9 +287,9 @@ func TestUpstreamResolver_Failover(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: trackingClient, - upstreamServers: []netip.AddrPort{upstream1, upstream2}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream1, upstream2}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -421,9 +369,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { resolver := &upstreamResolverBase{ ctx: ctx, upstreamClient: mockClient, - upstreamServers: []netip.AddrPort{upstream}, upstreamTimeout: UpstreamTimeout, } + resolver.addRace([]netip.AddrPort{upstream}) var responseMSG *dns.Msg responseWriter := &test.MockResponseWriter{ @@ -440,6 +388,133 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL") } +// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups +// configured for the same domain, with one broken group. The merge+race +// path should answer as fast as the working group and not pay the timeout +// of the broken one on every query. 
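Reviewer note: two distinct time budgets matter when reading this test. Across races the queries below run in parallel, while within a single race the per-upstream timeout is derived from raceMaxTotalTimeout and raceMinPerUpstreamTimeout in upstream.go. A worked example with the 5s cap and 2s floor:

    // 1 upstream in a race  -> plain upstreamTimeout
    // 2 upstreams           -> max(5s/2, 2s) = 2.5s each
    // 3 upstreams           -> max(5s/3, 2s) = 2s each; walking all three can
    //                          then take up to 6s, slightly over the 5s cap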
+func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) { + broken := netip.MustParseAddrPort("192.0.2.1:53") + working := netip.MustParseAddrPort("192.0.2.2:53") + successAnswer := "192.0.2.100" + timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")} + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + broken.String(): {err: timeoutErr}, + working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: 100 * time.Millisecond, + } + resolver.addRace([]netip.AddrPort{broken}) + resolver.addRace([]netip.AddrPort{working}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + start := time.Now() + resolver.ServeDNS(responseWriter, inputMSG) + elapsed := time.Since(start) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode) + require.NotEmpty(t, responseMSG.Answer) + assert.Contains(t, responseMSG.Answer[0].String(), successAnswer) + // Working group answers in a single RTT; the broken group's + // timeout (100ms) must not block the response. + assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout") +} + +// TestUpstreamResolver_AllGroupsFail checks that when every group fails the +// resolver returns SERVFAIL rather than leaking a partial response. +func TestUpstreamResolver_AllGroupsFail(t *testing.T) { + a := netip.MustParseAddrPort("192.0.2.1:53") + b := netip.MustParseAddrPort("192.0.2.2:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{a}) + resolver.addRace([]netip.AddrPort{b}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + require.NotNil(t, responseMSG) + assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode) +} + +// TestUpstreamResolver_HealthTracking verifies that query-path results are +// recorded into per-upstream health, which is what projects back to +// NSGroupState for status reporting. 
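Reviewer note: the test below exercises a single handler; when several handlers serve overlapping domains, the server-side projection merges their snapshots. One plausible merge, shown only as a sketch (mergeHealth and latest are hypothetical names; the actual tie-breaking lives in server.go), keeps the most recent observation per upstream:

    func mergeHealth(snapshots ...map[netip.AddrPort]UpstreamHealth) map[netip.AddrPort]UpstreamHealth {
        merged := make(map[netip.AddrPort]UpstreamHealth)
        for _, snap := range snapshots {
            for addr, h := range snap {
                cur, seen := merged[addr]
                // keep the record whose newest timestamp is most recent
                if !seen || latest(h).After(latest(cur)) {
                    merged[addr] = h
                }
            }
        }
        return merged
    }

    func latest(h UpstreamHealth) time.Time {
        if h.LastFail.After(h.LastOk) {
            return h.LastFail
        }
        return h.LastOk
    }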
+func TestUpstreamResolver_HealthTracking(t *testing.T) { + ok := netip.MustParseAddrPort("192.0.2.10:53") + bad := netip.MustParseAddrPort("192.0.2.11:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")}, + bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamTimeout: UpstreamTimeout, + } + resolver.addRace([]netip.AddrPort{ok, bad}) + + responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }} + resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA)) + + health := resolver.UpstreamHealth() + require.Contains(t, health, ok) + assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set") + assert.Empty(t, health[ok].LastErr) + + // bad upstream was never tried because ok answered first; its health + // should remain unset. + assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers") +} + func TestFormatFailures(t *testing.T) { testCases := []struct { name string diff --git a/client/internal/engine.go b/client/internal/engine.go index 8d7e02bd5..ea93ecede 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -504,16 +504,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) - e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool { - for _, routes := range e.routeManager.GetSelectedClientRoutes() { - for _, r := range routes { - if r.Network.Contains(ip) { - return true - } - } - } - return false - }) + e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes) if err = e.wgInterfaceCreate(); err != nil { log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) @@ -1336,9 +1327,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.networkSerial = serial - // Test received (upstream) servers for availability right away instead of upon usage. - // If no server of a server group responds this will disable the respective handler and retry later. - go e.dnsServer.ProbeAvailability() return nil } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 3923e153b..73535a5a8 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -53,6 +53,7 @@ type Manager interface { GetRouteSelector() *routeselector.RouteSelector GetClientRoutes() route.HAMap GetSelectedClientRoutes() route.HAMap + GetActiveClientRoutes() route.HAMap GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -477,6 +478,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap { return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) } +// GetActiveClientRoutes returns the subset of selected client routes +// that are currently reachable: the route's peer is Connected and is +// the one actively carrying the route (not just an HA sibling). 
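Reviewer note: a sketch of how the DNS side is expected to combine this with the selected set wired in via SetRouteSources (see engine.go above); s stands for the DNS server here and the combination is illustrative, not code in this diff:

    // overlay-routed but currently unreachable: the peer carrying the route
    // is not connected, so a failing query gets the startup grace window.
    selected := haMapContains(s.selectedRoutes(), upstream.Addr())
    active := haMapContains(s.activeRoutes(), upstream.Addr())
    routedButDown := selected && !active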
+func (m *DefaultManager) GetActiveClientRoutes() route.HAMap { + m.mux.Lock() + selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes)) + recorder := m.statusRecorder + m.mux.Unlock() + + if recorder == nil { + return selected + } + + out := make(route.HAMap, len(selected)) + for id, routes := range selected { + for _, r := range routes { + st, err := recorder.GetPeer(r.Peer) + if err != nil { + continue + } + if st.ConnStatus != peer.StatusConnected { + continue + } + if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute { + continue + } + out[id] = routes + break + } + } + return out +} + // GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { m.mux.Lock() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 66b5e30dd..937314995 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -19,6 +19,7 @@ type MockManager struct { GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap GetSelectedClientRoutesFunc func() route.HAMap + GetActiveClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route StopFunc func(manager *statemanager.Manager) } @@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap { return nil } +// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface +func (m *MockManager) GetActiveClientRoutes() route.HAMap { + if m.GetActiveClientRoutesFunc != nil { + return m.GetActiveClientRoutesFunc() + } + return nil +} + // GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { if m.GetClientRoutesWithNetIDFunc != nil {