diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index 46e07f98d..a5f5eb77a 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net/netip"
 	"net/url"
+	"slices"
 	"strings"
 	"sync"
 	"time"
@@ -44,6 +45,11 @@ const (
 	warningDelayBonusCap = 30 * time.Second
 )

+// errNoUsableNameservers signals that a merged-domain group has no usable
+// upstream servers. Callers should skip the group without treating it as a
+// build failure.
+var errNoUsableNameservers = errors.New("no usable nameservers")
+
 // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
 type ReadyListener interface {
 	OnReady()
@@ -315,6 +321,19 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
 	defer s.mux.Unlock()
 	s.selectedRoutes = selected
 	s.activeRoutes = active
+
+	// Permanent / iOS constructors build the root handler before the
+	// engine wires route sources, so its selectedRoutes callback would
+	// otherwise remain nil and overlay upstreams would be classified
+	// as public. Propagate the new accessors to existing handlers.
+	type routeSettable interface {
+		setSelectedRoutes(func() route.HAMap)
+	}
+	for _, entry := range s.dnsMuxMap {
+		if h, ok := entry.handler.(routeSettable); ok {
+			h.setSelectedRoutes(selected)
+		}
+	}
 }

 // RegisterHandler registers a handler for the given domains with the given priority.
@@ -778,7 +797,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
 		s.wgInterface,
 		s.statusRecorder,
 		s.hostsDNSHolder,
-		domain.Domain(nbdns.RootZone),
+		nbdns.RootZone,
 	)
 	if err != nil {
 		log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
@@ -861,11 +880,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam

 		update, err := s.buildMergedDomainHandler(domainGroup, priority)
 		if err != nil {
+			if errors.Is(err, errNoUsableNameservers) {
+				log.Errorf("no usable nameservers for domain=%s", domainGroup.domain)
+				continue
+			}
 			return nil, err
 		}
-		if update != nil {
-			muxUpdates = append(muxUpdates, *update)
-		}
+		muxUpdates = append(muxUpdates, *update)
 	}

 	return muxUpdates, nil
@@ -897,8 +918,7 @@ func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, p

 	if len(handler.upstreamServers) == 0 {
 		handler.Stop()
-		log.Errorf("no usable nameservers for domain=%s", domainGroup.domain)
-		return nil, nil
+		return nil, errNoUsableNameservers
 	}

 	log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority)
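The sentinel-error handling above (errNoUsableNameservers plus the errors.Is check in the caller) replaces the old convention of returning a nil update. A minimal, self-contained sketch of the same pattern; names like errNoUsableServers and buildGroup are illustrative and not part of this patch:

    package main

    import (
    	"errors"
    	"fmt"
    )

    // errNoUsableServers plays the role of errNoUsableNameservers above.
    var errNoUsableServers = errors.New("no usable servers")

    func buildGroup(servers []string) (string, error) {
    	if len(servers) == 0 {
    		return "", errNoUsableServers // signal "skip me", not "abort the build"
    	}
    	return fmt.Sprintf("handler(%v)", servers), nil
    }

    func buildAll(groups [][]string) ([]string, error) {
    	var handlers []string
    	for _, g := range groups {
    		h, err := buildGroup(g)
    		if errors.Is(err, errNoUsableServers) {
    			continue // skip the empty group, keep building the rest
    		}
    		if err != nil {
    			return nil, err // any other error still fails the whole build
    		}
    		handlers = append(handlers, h)
    	}
    	return handlers, nil
    }

    func main() {
    	handlers, err := buildAll([][]string{{"10.0.0.1:53"}, {}, {"10.0.0.2:53"}})
    	fmt.Println(handlers, err) // the empty group is skipped, err stays nil
    }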
@@ -927,6 +947,27 @@ func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []neti
 	return out
 }

+// usableNameServers returns the subset of nameServers the handler would
+// actually query. Matches filterNameServers without the warning logs, so
+// it's safe to call on every health-projection tick.
+func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []netip.AddrPort {
+	var runtimeIP netip.Addr
+	if s.service != nil {
+		runtimeIP = s.service.RuntimeIP()
+	}
+	var out []netip.AddrPort
+	for _, ns := range nameServers {
+		if ns.NSType != nbdns.UDPNameServerType {
+			continue
+		}
+		if runtimeIP.IsValid() && ns.IP == runtimeIP {
+			continue
+		}
+		out = append(out, ns.AddrPort())
+	}
+	return out
+}
+
 func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
 	// this will introduce a short period of time when the server is not able to handle DNS requests
 	for _, existing := range s.dnsMuxMap {
@@ -1044,11 +1085,14 @@ func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) {
 	}

 	now := time.Now()
-	delay := s.warningDelay(len(snap.selected))
+	delay := s.warningDelay(haMapRouteCount(snap.selected))
 	states := make([]peer.NSGroupState, 0, len(snap.groups))
 	seen := make(map[nsGroupID]struct{}, len(snap.groups))
 	for _, group := range snap.groups {
-		servers := nameServerAddrPorts(group.NameServers)
+		servers := s.usableNameServers(group.NameServers)
+		if len(servers) == 0 {
+			continue
+		}
 		verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now)
 		id := generateGroupKey(group)
 		seen[id] = struct{}{}
@@ -1069,7 +1113,10 @@
 			enabled = s.projectUnhealthy(p, servers, immediate, now, delay)
 		case nsVerdictUndecided:
 			// Stay Available until evidence says otherwise, unless a
-			// warning is already active for this group.
+			// warning is already active for this group. Also clear any
+			// prior Unhealthy streak so a later Unhealthy verdict starts
+			// a fresh grace window rather than inheriting a stale one.
+			p.unhealthySince = time.Time{}
 			enabled = !p.warningActive
 			groupErr = nil
 		}
@@ -1142,6 +1189,9 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
 // count. Scales gently: +1s per 100 routes, capped by
 // warningDelayBonusCap. Parallel handshakes mean handshake time grows
 // much slower than route count, so linear scaling would overcorrect.
+//
+// TODO: revisit the scaling curve with real-world data — the current
+// values are a reasonable starting point, not a measured fit.
 func (s *DefaultServer) warningDelay(routeCount int) time.Duration {
 	bonus := time.Duration(routeCount/100) * time.Second
 	if bonus > warningDelayBonusCap {
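For a feel of the numbers behind warningDelay's bonus: integer division means nothing is added below 100 routes, and only very large route counts hit the 30-second cap. A small standalone sketch reusing only the constant shown in this diff:

    package main

    import (
    	"fmt"
    	"time"
    )

    const warningDelayBonusCap = 30 * time.Second

    // bonusFor mirrors the bonus calculation in warningDelay: +1s per 100
    // routes, integer division, capped at warningDelayBonusCap.
    func bonusFor(routeCount int) time.Duration {
    	bonus := time.Duration(routeCount/100) * time.Second
    	if bonus > warningDelayBonusCap {
    		bonus = warningDelayBonusCap
    	}
    	return bonus
    }

    func main() {
    	for _, n := range []int{0, 99, 250, 1000, 10000} {
    		fmt.Printf("%5d routes -> +%v\n", n, bonusFor(n))
    	}
    	// 0 and 99 routes add nothing, 250 adds 2s, 1000 adds 10s,
    	// and 10000 hits the 30s cap.
    }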
@@ -1164,11 +1214,16 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
 	for _, srv := range servers {
 		addr := srv.Addr().Unmap()
 		overlay := overlayV4.IsValid() && overlayV4.Contains(addr)
-		routed := haMapContains(snap.selected, addr)
+		selMatched, selDynamic := haMapContains(snap.selected, addr)
+		// Treat an unknown (dynamic selected route) as possibly routed:
+		// the upstream might be reached through a dynamic route whose
+		// Network hasn't resolved yet, and classifying it as public would
+		// bypass the startup grace window.
+		routed := selMatched || selDynamic
 		if !overlay && !routed {
 			return true
 		}
-		if haMapContains(snap.active, addr) {
+		if actMatched, _ := haMapContains(snap.active, addr); actMatched {
 			return true
 		}
 	}
@@ -1290,15 +1345,6 @@ func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassificat
 	return upstreamStale
 }

-// nameServerAddrPorts flattens a NameServer list to AddrPorts.
-func nameServerAddrPorts(ns []nbdns.NameServer) []netip.AddrPort {
-	out := make([]netip.AddrPort, 0, len(ns))
-	for _, n := range ns {
-		out = append(out, n.AddrPort())
-	}
-	return out
-}
-
 func joinAddrPorts(servers []netip.AddrPort) string {
 	parts := make([]string, 0, len(servers))
 	for _, s := range servers {
@@ -1307,12 +1353,18 @@
 	return strings.Join(parts, ", ")
 }

+// generateGroupKey returns a stable identity for an NS group so health
+// state (everHealthy / warningActive) survives reorderings in the
+// configured nameserver or domain lists.
 func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID {
-	var servers []string
+	servers := make([]string, 0, len(nsGroup.NameServers))
 	for _, ns := range nsGroup.NameServers {
 		servers = append(servers, ns.AddrPort().String())
 	}
-	return nsGroupID(fmt.Sprintf("%v_%v", servers, nsGroup.Domains))
+	slices.Sort(servers)
+	domains := slices.Clone(nsGroup.Domains)
+	slices.Sort(domains)
+	return nsGroupID(fmt.Sprintf("%v_%v", servers, domains))
 }

 // groupNSGroupsByDomain groups nameserver groups by their match domains
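Sorting inside generateGroupKey makes the key order-insensitive, so reordering servers or domains in the management config no longer resets a group's everHealthy / warningActive state. A reduced sketch of the same idea with plain string slices (the real function works on nbdns.NameServerGroup):

    package main

    import (
    	"fmt"
    	"slices"
    )

    // stableKey builds an order-insensitive identity from servers and domains,
    // mirroring what generateGroupKey does after sorting both lists.
    func stableKey(servers, domains []string) string {
    	s := slices.Clone(servers)
    	d := slices.Clone(domains)
    	slices.Sort(s)
    	slices.Sort(d)
    	return fmt.Sprintf("%v_%v", s, d)
    }

    func main() {
    	a := stableKey([]string{"1.1.1.1:53", "8.8.8.8:53"}, []string{"corp.example", "lab.example"})
    	b := stableKey([]string{"8.8.8.8:53", "1.1.1.1:53"}, []string{"lab.example", "corp.example"})
    	fmt.Println(a == b) // true: same group identity despite reordering
    }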
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index 7b596d3fd..a42a60164 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -1001,7 +1001,6 @@ type mockHandler struct {

 func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
 func (m *mockHandler) Stop() {}
-func (m *mockHandler) ProbeAvailability(context.Context) {}
 func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }

 type mockService struct{}
diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go
index 3df69517a..2cb9ad199 100644
--- a/client/internal/dns/upstream.go
+++ b/client/internal/dns/upstream.go
@@ -163,8 +163,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
 	return ""
 }

-// contextWithupstreamProtocolResult stores a mutable result holder in the context.
-func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
+// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
+func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
 	r := &upstreamProtocolResult{}
 	return context.WithValue(ctx, upstreamProtocolKey{}, r), r
 }
@@ -196,16 +196,20 @@ func (u *upstreamResolverBase) String() string {
 	return fmt.Sprintf("Upstream %s", u.flatUpstreams())
 }

-// ID returns the unique handler ID
+// ID returns the unique handler ID. Race groupings and within-race
+// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
+// the same servers but with different semantics (serial fallback vs
+// parallel race), so their handlers must not collide.
 func (u *upstreamResolverBase) ID() types.HandlerID {
-	servers := u.flatUpstreams()
-	slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
-
 	hash := sha256.New()
 	hash.Write([]byte(u.domain.PunycodeString() + ":"))
-	for _, s := range servers {
-		hash.Write([]byte(s.String()))
-		hash.Write([]byte("|"))
+	for _, race := range u.upstreamServers {
+		hash.Write([]byte("["))
+		for _, s := range race {
+			hash.Write([]byte(s.String()))
+			hash.Write([]byte("|"))
+		}
+		hash.Write([]byte("]"))
 	}
 	return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
 }
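To see why the race grouping is hashed rather than the flattened server set: the sketch below feeds two layouts of the same two servers through the same bracket-and-pipe walk the new ID() uses, and the results differ. The helper raceID and its inputs are illustrative only:

    package main

    import (
    	"crypto/sha256"
    	"encoding/hex"
    	"fmt"
    )

    // raceID mimics the new ID() layout walk: each race group is wrapped in
    // brackets and servers are separated by '|', so grouping changes the hash.
    func raceID(domain string, groups [][]string) string {
    	h := sha256.New()
    	h.Write([]byte(domain + ":"))
    	for _, g := range groups {
    		h.Write([]byte("["))
    		for _, s := range g {
    			h.Write([]byte(s))
    			h.Write([]byte("|"))
    		}
    		h.Write([]byte("]"))
    	}
    	return "upstream-" + hex.EncodeToString(h.Sum(nil)[:8])
    }

    func main() {
    	serial := raceID("example.com", [][]string{{"1.1.1.1:53", "8.8.8.8:53"}}) // one group: serial fallback
    	racing := raceID("example.com", [][]string{{"1.1.1.1:53"}, {"8.8.8.8:53"}}) // two groups: parallel race
    	fmt.Println(serial != racing) // true: same servers, different handler identity
    }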
@@ -228,13 +232,11 @@ func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
 	return out
 }

-// isRouted reports whether ip falls inside any client route the admin
-// has selected.
-func (u *upstreamResolverBase) isRouted(ip netip.Addr) bool {
-	if u.selectedRoutes == nil {
-		return false
-	}
-	return haMapContains(u.selectedRoutes(), ip)
+// setSelectedRoutes swaps the accessor used to classify overlay-routed
+// upstreams. Called when route sources are wired after the handler was
+// built (permanent / iOS constructors).
+func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
+	u.selectedRoutes = selected
 }

 func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
@@ -313,6 +315,8 @@ func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter
 	// after the coordinator has returned.
 	results := make(chan raceResult, len(groups))
 	for _, g := range groups {
+		// tryRace clones the request per attempt, so workers never share
+		// a *dns.Msg and concurrent EDNS0 mutations can't race.
 		go func(g upstreamRace) { results <- u.tryRace(raceCtx, r, g) }(g)
 	}

@@ -337,7 +341,14 @@ func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group up

 	timeout := u.upstreamTimeout
 	if len(group) > 1 {
+		// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
+		// still honor raceMinPerUpstreamTimeout as a floor for correctness
+		// on slow links, but the outer context ensures the combined walk
+		// cannot exceed the cap regardless of group size.
 		timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
+		var cancel context.CancelFunc
+		ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
+		defer cancel()
 	}

 	var failures []upstreamFailure
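The per-upstream budget in tryRace is the total cap divided by group size, floored at the per-upstream minimum, and the new outer context is what keeps a large group from spending more than the cap once the floor kicks in. The constants below are assumed values for illustration; the diff does not show what raceMaxTotalTimeout and raceMinPerUpstreamTimeout actually are:

    package main

    import (
    	"fmt"
    	"time"
    )

    // Assumed values for illustration only; the real constants live elsewhere
    // in the package.
    const (
    	raceMaxTotalTimeout       = 4 * time.Second
    	raceMinPerUpstreamTimeout = 1 * time.Second
    )

    func perUpstreamTimeout(groupSize int) time.Duration {
    	return max(raceMaxTotalTimeout/time.Duration(groupSize), raceMinPerUpstreamTimeout)
    }

    func main() {
    	for _, n := range []int{2, 4, 8} {
    		fmt.Printf("group of %d -> %v per upstream\n", n, perUpstreamTimeout(n))
    	}
    	// With these assumed constants: 2 -> 2s, 4 -> 1s, 8 -> 1s (floored).
    	// For the group of 8 the serial walk could use up to 8s of per-upstream
    	// budget, which is why tryRace also wraps ctx in a total-cap timeout.
    }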
@@ -345,7 +356,11 @@
 		if ctx.Err() != nil {
 			return raceResult{failures: failures}
 		}
-		msg, proto, failure := u.queryUpstream(ctx, r, upstream, timeout)
+		// Clone the request per attempt: the exchange path mutates EDNS0
+		// options in-place, so reusing the same *dns.Msg across sequential
+		// upstreams would carry those mutations (e.g. a reduced UDP size)
+		// into the next attempt.
+		msg, proto, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
 		if failure != nil {
 			failures = append(failures, *failure)
 			continue
@@ -358,12 +373,19 @@

 func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (*dns.Msg, string, *upstreamFailure) {
 	ctx, cancel := context.WithTimeout(parentCtx, timeout)
 	defer cancel()

-	ctx, upstreamProto := contextWithupstreamProtocolResult(ctx)
+	ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
 	startTime := time.Now()
 	rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
 	if err != nil {
+		// A parent cancellation (e.g., another race won and the coordinator
+		// cancelled the losers) is not an upstream failure. Check both the
+		// error chain and the parent context: a transport may surface the
+		// cancellation as a read/deadline error rather than context.Canceled.
+		if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
+			return nil, "", &upstreamFailure{upstream: upstream, reason: "canceled"}
+		}
 		failure := u.handleUpstreamError(err, upstream, startTime)
 		u.markUpstreamFail(upstream, failure.reason)
 		return nil, "", failure
@@ -522,13 +544,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
 // ExchangeWithFallback exchanges a DNS message with the upstream server.
 // It first tries to use UDP, and if it is truncated, it falls back to TCP.
 // If the inbound request came over TCP (via context), it skips the UDP attempt.
-// If the passed context is nil, this will use Exchange instead of ExchangeContext.
 func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
 	// If the request came in over TCP, go straight to TCP upstream.
 	if dnsProtocolFromContext(ctx) == protoTCP {
-		tcpClient := *client
-		tcpClient.Net = protoTCP
-		rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
+		rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
 		if err != nil {
 			return nil, t, fmt.Errorf("with tcp: %w", err)
 		}
@@ -548,18 +567,7 @@
 		opt.SetUDPSize(maxUDPPayload)
 	}

-	var (
-		rm  *dns.Msg
-		t   time.Duration
-		err error
-	)
-
-	if ctx == nil {
-		rm, t, err = client.Exchange(r, upstream)
-	} else {
-		rm, t, err = client.ExchangeContext(ctx, r, upstream)
-	}
-
+	rm, t, err := client.ExchangeContext(ctx, r, upstream)
 	if err != nil {
 		return nil, t, fmt.Errorf("with udp: %w", err)
 	}
@@ -573,15 +581,7 @@
 	// data than the client's buffer, we could truncate locally and skip
 	// the TCP retry.

-	tcpClient := *client
-	tcpClient.Net = protoTCP
-
-	if ctx == nil {
-		rm, t, err = tcpClient.Exchange(r, upstream)
-	} else {
-		rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
-	}
-
+	rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
 	if err != nil {
 		return nil, t, fmt.Errorf("with tcp: %w", err)
 	}
@@ -595,6 +595,25 @@
 	return rm, t, nil
 }

+// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
+// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
+// the tunnel interface), it is converted to the equivalent *net.TCPAddr
+// so net.Dialer doesn't reject the TCP dial with "mismatched local
+// address type".
+func toTCPClient(c *dns.Client) *dns.Client {
+	tcp := *c
+	tcp.Net = protoTCP
+	if tcp.Dialer == nil {
+		return &tcp
+	}
+	d := *tcp.Dialer
+	if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
+		d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
+	}
+	tcp.Dialer = &d
+	return &tcp
+}
+
 // ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
 // This is needed when netstack is enabled to reach peer IPs through the tunnel.
 func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
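The LocalAddr conversion in toTCPClient matters because net.Dialer type-checks the bound local address against the network being dialed, so a leftover *net.UDPAddr fails every TCP dial up front. A small reproduction with placeholder addresses:

    package main

    import (
    	"fmt"
    	"net"
    	"time"
    )

    func main() {
    	// The UDP client on iOS is built with a Dialer whose LocalAddr pins the
    	// source to the tunnel address. 100.64.0.2 is a placeholder here.
    	local := &net.UDPAddr{IP: net.ParseIP("100.64.0.2")}
    	d := &net.Dialer{LocalAddr: local, Timeout: time.Second}

    	// Reusing that Dialer for TCP fails before any packet is sent, with an
    	// error along the lines of "mismatched local address type".
    	if _, err := d.Dial("tcp", "192.0.2.1:53"); err != nil {
    		fmt.Println("udp-bound dialer:", err)
    	}

    	// Converting the bound address, as toTCPClient does, keeps the same
    	// source IP but satisfies the tcp dial's type check.
    	d.LocalAddr = &net.TCPAddr{IP: local.IP, Port: local.Port, Zone: local.Zone}
    	if _, err := d.Dial("tcp", "192.0.2.1:53"); err != nil {
    		fmt.Println("tcp-bound dialer:", err) // may still fail, but not on address type
    	}
    }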
@@ -736,22 +755,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
 	return bestMatch
 }

-// haMapContains reports whether any route in the map contains ip.
-//
-// Gap: dynamic (domain-based) routes carry a placeholder Network that
-// never matches a real address, so an upstream reached via a dynamic
-// route is classified as "not routed" here. The DNS health path then
-// emits failure events immediately for such upstreams instead of
-// applying the startup grace window. Rare (DNS servers are usually
-// designated by IP, not by domain) but worth revisiting if DoT/DoH-style
-// upstreams or /etc/hosts-style domain routing to DNS become supported.
-func haMapContains(hm route.HAMap, ip netip.Addr) bool {
+// haMapRouteCount returns the total number of routes across all HA
+// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
+// routes per key, so len(hm) is the number of HA groups, not routes.
+func haMapRouteCount(hm route.HAMap) int {
+	total := 0
+	for _, routes := range hm {
+		total += len(routes)
+	}
+	return total
+}
+
+// haMapContains checks whether ip is covered by any concrete prefix in
+// the HA map. haveDynamic is reported separately: dynamic (domain-based)
+// routes carry a placeholder Network that can't be prefix-checked, so we
+// can't know at this point whether ip is reached through one. Callers
+// decide how to interpret the unknown: health projection treats it as
+// "possibly routed" to avoid emitting false-positive warnings during
+// startup, while iOS dial selection requires a concrete match before
+// binding to the tunnel.
+func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
 	for _, routes := range hm {
 		for _, r := range routes {
+			if r.IsDynamic() {
+				haveDynamic = true
+				continue
+			}
 			if r.Network.Contains(ip) {
-				return true
+				return true, haveDynamic
 			}
 		}
 	}
-	return false
+	return false, haveDynamic
 }
diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go
index 450152b2e..65dc0bc50 100644
--- a/client/internal/dns/upstream_ios.go
+++ b/client/internal/dns/upstream_ios.go
@@ -66,7 +66,14 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
 	} else {
 		upstreamIP = upstreamIP.Unmap()
 	}
-	needsPrivate := u.lNet.Contains(upstreamIP) || u.isRouted(upstreamIP)
+	var routed bool
+	if u.selectedRoutes != nil {
+		// Only a concrete prefix match binds to the tunnel: dialing
+		// through a private client for an upstream we can't prove is
+		// routed would break public resolvers.
+		routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
+	}
+	needsPrivate := u.lNet.Contains(upstreamIP) || routed
 	if needsPrivate {
 		log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
 		client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
@@ -75,8 +82,7 @@
 		}
 	}

-	// Cannot use client.ExchangeContext because it overwrites our Dialer
-	return ExchangeWithFallback(nil, client, r, upstream)
+	return ExchangeWithFallback(ctx, client, r, upstream)
 }

 // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
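The two call sites of haMapContains read the (matched, haveDynamic) pair with opposite defaults, which is easiest to see side by side. A condensed sketch of just the two policies; the lookup type stands in for the real HA map query:

    package main

    import "fmt"

    // lookup stands in for haMapContains: matched means a concrete prefix covered
    // the IP, haveDynamic means at least one domain-based route could not be checked.
    type lookup struct{ matched, haveDynamic bool }

    // healthTreatsAsRouted mirrors groupHasImmediateUpstream: an unknown dynamic
    // route counts as "possibly routed" so the startup grace window still applies.
    func healthTreatsAsRouted(l lookup) bool { return l.matched || l.haveDynamic }

    // iosBindsToTunnel mirrors upstreamResolverIOS.exchange: only a concrete
    // prefix match is allowed to force the private, tunnel-bound client.
    func iosBindsToTunnel(l lookup) bool { return l.matched }

    func main() {
    	unknown := lookup{matched: false, haveDynamic: true}
    	fmt.Println(healthTreatsAsRouted(unknown)) // true: don't alarm during startup
    	fmt.Println(iosBindsToTunnel(unknown))     // false: keep using the default dialer
    }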
diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go
index b0c510b18..60195b5ac 100644
--- a/client/internal/dns/upstream_test.go
+++ b/client/internal/dns/upstream_test.go
@@ -6,6 +6,7 @@ import (
 	"net"
 	"net/netip"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
 	return "", nil
 }

-type mockUpstreamResolver struct {
-	r   *dns.Msg
-	rtt time.Duration
-	err error
-}
-
-// exchange mock implementation of exchange from upstreamResolver
-func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
-	return c.r, c.rtt, c.err
-}
-
 type mockUpstreamResponse struct {
-	msg *dns.Msg
-	err error
+	msg   *dns.Msg
+	err   error
+	delay time.Duration
 }

 type mockUpstreamResolverPerServer struct {
@@ -153,11 +144,19 @@ type mockUpstreamResolverPerServer struct {
 	rtt       time.Duration
 }

-func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
-	if r, ok := c.responses[upstream]; ok {
-		return r.msg, c.rtt, r.err
+func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
+	r, ok := c.responses[upstream]
+	if !ok {
+		return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
 	}
-	return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
+	if r.delay > 0 {
+		select {
+		case <-time.After(r.delay):
+		case <-ctx.Done():
+			return nil, c.rtt, ctx.Err()
+		}
+	}
+	return r.msg, c.rtt, r.err
 }

 func TestUpstreamResolver_Failover(t *testing.T) {
@@ -400,7 +399,10 @@ func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
 	mockClient := &mockUpstreamResolverPerServer{
 		responses: map[string]mockUpstreamResponse{
-			broken.String():  {err: timeoutErr},
+			// Force the broken upstream to only unblock via timeout /
+			// cancellation so the assertion below can't pass if races
+			// were run serially.
+			broken.String():  {err: timeoutErr, delay: 500 * time.Millisecond},
 			working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
 		},
 		rtt: time.Millisecond,
 	}
@@ -412,7 +414,7 @@
 	resolver := &upstreamResolverBase{
 		ctx:             ctx,
 		upstreamClient:  mockClient,
-		upstreamTimeout: 100 * time.Millisecond,
+		upstreamTimeout: 250 * time.Millisecond,
 	}
 	resolver.addRace([]netip.AddrPort{broken})
 	resolver.addRace([]netip.AddrPort{working})
@@ -740,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
 	// Verify that a client EDNS0 larger than our MTU-derived limit gets
 	// capped in the outgoing request so the upstream doesn't send a
 	// response larger than our read buffer.
-	var receivedUDPSize uint16
+	var receivedUDPSize atomic.Uint32
 	udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
 		if opt := r.IsEdns0(); opt != nil {
-			receivedUDPSize = opt.UDPSize()
+			receivedUDPSize.Store(uint32(opt.UDPSize()))
 		}
 		m := new(dns.Msg)
 		m.SetReply(r)
@@ -774,7 +776,7 @@
 	require.NotNil(t, rm)

 	expectedMax := uint16(currentMTU - ipUDPHeaderSize)
-	assert.Equal(t, expectedMax, receivedUDPSize,
+	assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
 		"upstream should see capped EDNS0, not the client's 4096")
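The delay-plus-ctx.Done shape in the reworked mock is what lets TestUpstreamResolver_RaceAcrossGroups prove the races run in parallel: the losing upstream stays blocked until the coordinator cancels it instead of returning instantly. The same shape in isolation, with illustrative names:

    package main

    import (
    	"context"
    	"errors"
    	"fmt"
    	"time"
    )

    // slowExchange blocks for delay unless the caller cancels first, the same way
    // the reworked mockUpstreamResolverPerServer.exchange does for delayed responses.
    func slowExchange(ctx context.Context, delay time.Duration) error {
    	select {
    	case <-time.After(delay):
    		return errors.New("timeout from upstream")
    	case <-ctx.Done():
    		return ctx.Err()
    	}
    }

    func main() {
    	ctx, cancel := context.WithCancel(context.Background())
    	done := make(chan error, 1)
    	go func() { done <- slowExchange(ctx, 500*time.Millisecond) }()

    	// The "winning" upstream answers quickly, so the coordinator cancels the loser.
    	time.Sleep(10 * time.Millisecond)
    	cancel()

    	fmt.Println(<-done) // context.Canceled, well before the 500ms delay elapses
    }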