Harden race fan-out and fix lint

This commit is contained in:
Viktor Liu
2026-04-23 15:09:11 +02:00
parent c102592735
commit d0f9d80c3a
5 changed files with 195 additions and 103 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"slices"
"strings"
"sync"
"time"
@@ -44,6 +45,11 @@ const (
warningDelayBonusCap = 30 * time.Second
)
// errNoUsableNameservers signals that a merged-domain group has no usable
// upstream servers. Callers should skip the group without treating it as a
// build failure.
var errNoUsableNameservers = errors.New("no usable nameservers")
// ReadyListener is a notification mechanism that indicates the server is ready to handle host DNS address changes
type ReadyListener interface {
OnReady()
@@ -315,6 +321,19 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
defer s.mux.Unlock()
s.selectedRoutes = selected
s.activeRoutes = active
// Permanent / iOS constructors build the root handler before the
// engine wires route sources, so its selectedRoutes callback would
// otherwise remain nil and overlay upstreams would be classified
// as public. Propagate the new accessors to existing handlers.
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
for _, entry := range s.dnsMuxMap {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
}
}
// RegisterHandler registers a handler for the given domains with the given priority.
@@ -778,7 +797,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domain.Domain(nbdns.RootZone),
nbdns.RootZone,
)
if err != nil {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
@@ -861,11 +880,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
update, err := s.buildMergedDomainHandler(domainGroup, priority)
if err != nil {
if errors.Is(err, errNoUsableNameservers) {
log.Errorf("no usable nameservers for domain=%s", domainGroup.domain)
continue
}
return nil, err
}
if update != nil {
muxUpdates = append(muxUpdates, *update)
}
muxUpdates = append(muxUpdates, *update)
}
return muxUpdates, nil
@@ -897,8 +918,7 @@ func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, p
if len(handler.upstreamServers) == 0 {
handler.Stop()
log.Errorf("no usable nameservers for domain=%s", domainGroup.domain)
return nil, nil
return nil, errNoUsableNameservers
}
log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority)
@@ -927,6 +947,27 @@ func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []neti
return out
}
// usableNameServers returns the subset of nameServers the handler would
// actually query. Matches filterNameServers without the warning logs, so
// it's safe to call on every health-projection tick.
func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []netip.AddrPort {
var runtimeIP netip.Addr
if s.service != nil {
runtimeIP = s.service.RuntimeIP()
}
var out []netip.AddrPort
for _, ns := range nameServers {
if ns.NSType != nbdns.UDPNameServerType {
continue
}
if runtimeIP.IsValid() && ns.IP == runtimeIP {
continue
}
out = append(out, ns.AddrPort())
}
return out
}
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
@@ -1044,11 +1085,14 @@ func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) {
}
now := time.Now()
delay := s.warningDelay(len(snap.selected))
delay := s.warningDelay(haMapRouteCount(snap.selected))
states := make([]peer.NSGroupState, 0, len(snap.groups))
seen := make(map[nsGroupID]struct{}, len(snap.groups))
for _, group := range snap.groups {
servers := nameServerAddrPorts(group.NameServers)
servers := s.usableNameServers(group.NameServers)
if len(servers) == 0 {
continue
}
verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now)
id := generateGroupKey(group)
seen[id] = struct{}{}
@@ -1069,7 +1113,10 @@ func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) {
enabled = s.projectUnhealthy(p, servers, immediate, now, delay)
case nsVerdictUndecided:
// Stay Available until evidence says otherwise, unless a
// warning is already active for this group.
// warning is already active for this group. Also clear any
// prior Unhealthy streak so a later Unhealthy verdict starts
// a fresh grace window rather than inheriting a stale one.
p.unhealthySince = time.Time{}
enabled = !p.warningActive
groupErr = nil
}
@@ -1142,6 +1189,9 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
// much slower than route count, so linear scaling would overcorrect.
//
// TODO: revisit the scaling curve with real-world data — the current
// values are a reasonable starting point, not a measured fit.
func (s *DefaultServer) warningDelay(routeCount int) time.Duration {
bonus := time.Duration(routeCount/100) * time.Second
if bonus > warningDelayBonusCap {
@@ -1164,11 +1214,16 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
for _, srv := range servers {
addr := srv.Addr().Unmap()
overlay := overlayV4.IsValid() && overlayV4.Contains(addr)
routed := haMapContains(snap.selected, addr)
selMatched, selDynamic := haMapContains(snap.selected, addr)
// Treat an unknown (a dynamic selected route is present) as possibly
// routed: the upstream might be reached through a dynamic route whose
// Network hasn't resolved yet, and classifying it as public would
// bypass the startup grace window.
routed := selMatched || selDynamic
if !overlay && !routed {
return true
}
if haMapContains(snap.active, addr) {
if actMatched, _ := haMapContains(snap.active, addr); actMatched {
return true
}
}
@@ -1290,15 +1345,6 @@ func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassificat
return upstreamStale
}
// nameServerAddrPorts flattens a NameServer list to AddrPorts.
func nameServerAddrPorts(ns []nbdns.NameServer) []netip.AddrPort {
out := make([]netip.AddrPort, 0, len(ns))
for _, n := range ns {
out = append(out, n.AddrPort())
}
return out
}
func joinAddrPorts(servers []netip.AddrPort) string {
parts := make([]string, 0, len(servers))
for _, s := range servers {
@@ -1307,12 +1353,18 @@ func joinAddrPorts(servers []netip.AddrPort) string {
return strings.Join(parts, ", ")
}
// generateGroupKey returns a stable identity for an NS group so health
// state (everHealthy / warningActive) survives reorderings in the
// configured nameserver or domain lists.
func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID {
var servers []string
servers := make([]string, 0, len(nsGroup.NameServers))
for _, ns := range nsGroup.NameServers {
servers = append(servers, ns.AddrPort().String())
}
return nsGroupID(fmt.Sprintf("%v_%v", servers, nsGroup.Domains))
slices.Sort(servers)
domains := slices.Clone(nsGroup.Domains)
slices.Sort(domains)
return nsGroupID(fmt.Sprintf("%v_%v", servers, domains))
}
// groupNSGroupsByDomain groups nameserver groups by their match domains
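The sorted-key change is easiest to see in isolation. A minimal sketch, with plain strings standing in for the repo's nbdns.NameServerGroup and nsGroupID (the helper and values below are illustrative, not the actual types): reordering servers or domains produces the same key, so health state keyed on it survives config reorderings.

package main

import (
	"fmt"
	"slices"
)

// nsGroupKey mirrors the sorted-key idea on plain strings; this helper and
// its inputs are illustrative stand-ins, not the repo's generateGroupKey.
func nsGroupKey(servers, domains []string) string {
	s := slices.Clone(servers)
	slices.Sort(s)
	d := slices.Clone(domains)
	slices.Sort(d)
	return fmt.Sprintf("%v_%v", s, d)
}

func main() {
	a := nsGroupKey([]string{"10.0.0.1:53", "10.0.0.2:53"}, []string{"corp.example", "lab.example"})
	b := nsGroupKey([]string{"10.0.0.2:53", "10.0.0.1:53"}, []string{"lab.example", "corp.example"})
	fmt.Println(a == b) // true: reordering no longer changes the group identity
}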

View File

@@ -1001,7 +1001,6 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}

View File

@@ -163,8 +163,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -196,16 +196,20 @@ func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}
// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := u.flatUpstreams()
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@@ -228,13 +232,11 @@ func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
return out
}
// isRouted reports whether ip falls inside any client route the admin
// has selected.
func (u *upstreamResolverBase) isRouted(ip netip.Addr) bool {
if u.selectedRoutes == nil {
return false
}
return haMapContains(u.selectedRoutes(), ip)
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
@@ -313,6 +315,8 @@ func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
@@ -337,7 +341,14 @@ func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
}
var failures []upstreamFailure
@@ -345,7 +356,11 @@ func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group up
if ctx.Err() != nil {
return raceResult{failures: failures}
}
msg, proto, failure := u.queryUpstream(ctx, r, upstream, timeout)
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
msg, proto, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
@@ -358,12 +373,19 @@ func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group up
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (*dns.Msg, string, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithupstreamProtocolResult(ctx)
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
if err != nil {
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return nil, "", &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return nil, "", failure
@@ -522,13 +544,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -548,18 +567,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -573,15 +581,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -595,6 +595,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
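The failure mode toTCPClient works around can be reproduced with net.Dialer alone. A minimal sketch with illustrative addresses (100.64.0.2 and 192.0.2.53 are placeholders, not values from this change):

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	// A dialer pinned to a UDP local address, as the iOS path does to keep
	// the source IP on the tunnel interface.
	d := &net.Dialer{
		Timeout:   2 * time.Second,
		LocalAddr: &net.UDPAddr{IP: net.ParseIP("100.64.0.2")},
	}

	// Dialing TCP while LocalAddr is a *net.UDPAddr fails during address
	// resolution, before any packet is sent.
	if _, err := d.Dial("tcp", "192.0.2.53:53"); err != nil {
		fmt.Println(err) // contains "mismatched local address type"
	}

	// Converting the bound address to *net.TCPAddr (what toTCPClient does)
	// keeps the source-IP pinning and lets the TCP dial proceed; it may
	// still fail to connect here, but not on the address type.
	d.LocalAddr = &net.TCPAddr{IP: net.ParseIP("100.64.0.2")}
	_, _ = d.Dial("tcp", "192.0.2.53:53")
}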
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -736,22 +755,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}
// haMapContains reports whether any route in the map contains ip.
//
// Gap: dynamic (domain-based) routes carry a placeholder Network that
// never matches a real address, so an upstream reached via a dynamic
// route is classified as "not routed" here. The DNS health path then
// emits failure events immediately for such upstreams instead of
// applying the startup grace window. Rare (DNS servers are usually
// designated by IP, not by domain) but worth revisiting if DoT/DoH-style
// upstreams or /etc/hosts-style domain routing to DNS become supported.
func haMapContains(hm route.HAMap, ip netip.Addr) bool {
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
return total
}
// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true
return true, haveDynamic
}
}
}
return false
return false, haveDynamic
}
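How the two callers read the (matched, haveDynamic) pair can be shown with simplified stand-ins (the rt type and values below are illustrative, not the repo's route.Route / route.HAMap):

package main

import (
	"fmt"
	"net/netip"
)

// rt is a minimal stand-in for a route entry: either a concrete prefix or a
// dynamic (domain-based) route whose network is only a placeholder.
type rt struct {
	network netip.Prefix
	dynamic bool
}

func contains(routes []rt, ip netip.Addr) (matched, haveDynamic bool) {
	for _, r := range routes {
		if r.dynamic {
			haveDynamic = true
			continue
		}
		if r.network.Contains(ip) {
			return true, haveDynamic
		}
	}
	return false, haveDynamic
}

func main() {
	routes := []rt{
		{dynamic: true},                                   // domain route, network not resolvable yet
		{network: netip.MustParsePrefix("10.20.0.0/16")},
	}
	upstream := netip.MustParseAddr("192.0.2.53")

	matched, haveDynamic := contains(routes, upstream)
	fmt.Println("health projection, possibly routed:", matched || haveDynamic) // true
	fmt.Println("iOS dial, bind to tunnel:", matched)                          // false
}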

View File

@@ -66,7 +66,14 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else {
upstreamIP = upstreamIP.Unmap()
}
needsPrivate := u.lNet.Contains(upstreamIP) || u.isRouted(upstreamIP)
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := u.lNet.Contains(upstreamIP) || routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
@@ -75,8 +82,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}
type mockUpstreamResolverPerServer struct {
@@ -153,11 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}
func TestUpstreamResolver_Failover(t *testing.T) {
@@ -400,7 +399,10 @@ func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
broken.String(): {err: timeoutErr},
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
@@ -412,7 +414,7 @@ func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 100 * time.Millisecond,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
@@ -740,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -774,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}