Harden race fan-out and fix lint
@@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
"slices"
"strings"
"sync"
"time"
@@ -44,6 +45,11 @@ const (
warningDelayBonusCap = 30 * time.Second
)

// errNoUsableNameservers signals that a merged-domain group has no usable
// upstream servers. Callers should skip the group without treating it as a
// build failure.
var errNoUsableNameservers = errors.New("no usable nameservers")

// ReadyListener is a notification mechanism that indicates the server is ready to handle host DNS address changes
type ReadyListener interface {
OnReady()
@@ -315,6 +321,19 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
defer s.mux.Unlock()
s.selectedRoutes = selected
s.activeRoutes = active

// Permanent / iOS constructors build the root handler before the
// engine wires route sources, so its selectedRoutes callback would
// otherwise remain nil and overlay upstreams would be classified
// as public. Propagate the new accessors to existing handlers.
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
for _, entry := range s.dnsMuxMap {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
}
}

// RegisterHandler registers a handler for the given domains with the given priority.
@@ -778,7 +797,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domain.Domain(nbdns.RootZone),
nbdns.RootZone,
)
if err != nil {
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
@@ -861,11 +880,13 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam

update, err := s.buildMergedDomainHandler(domainGroup, priority)
if err != nil {
if errors.Is(err, errNoUsableNameservers) {
log.Errorf("no usable nameservers for domain=%s", domainGroup.domain)
continue
}
return nil, err
}
if update != nil {
muxUpdates = append(muxUpdates, *update)
}
muxUpdates = append(muxUpdates, *update)
}

return muxUpdates, nil
@@ -897,8 +918,7 @@ func (s *DefaultServer) buildMergedDomainHandler(domainGroup nsGroupsByDomain, p

if len(handler.upstreamServers) == 0 {
handler.Stop()
log.Errorf("no usable nameservers for domain=%s", domainGroup.domain)
return nil, nil
return nil, errNoUsableNameservers
}

log.Debugf("creating merged handler for domain=%s with %d group(s) priority=%d", domainGroup.domain, len(handler.upstreamServers), priority)
@@ -927,6 +947,27 @@ func (s *DefaultServer) filterNameServers(nameServers []nbdns.NameServer) []neti
return out
}

// usableNameServers returns the subset of nameServers the handler would
// actually query. Matches filterNameServers without the warning logs, so
// it's safe to call on every health-projection tick.
func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []netip.AddrPort {
var runtimeIP netip.Addr
if s.service != nil {
runtimeIP = s.service.RuntimeIP()
}
var out []netip.AddrPort
for _, ns := range nameServers {
if ns.NSType != nbdns.UDPNameServerType {
continue
}
if runtimeIP.IsValid() && ns.IP == runtimeIP {
continue
}
out = append(out, ns.AddrPort())
}
return out
}

func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
for _, existing := range s.dnsMuxMap {
@@ -1044,11 +1085,14 @@ func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) {
}

now := time.Now()
delay := s.warningDelay(len(snap.selected))
delay := s.warningDelay(haMapRouteCount(snap.selected))
states := make([]peer.NSGroupState, 0, len(snap.groups))
seen := make(map[nsGroupID]struct{}, len(snap.groups))
for _, group := range snap.groups {
servers := nameServerAddrPorts(group.NameServers)
servers := s.usableNameServers(group.NameServers)
if len(servers) == 0 {
continue
}
verdict, groupErr := evaluateNSGroupHealth(snap.merged, servers, now)
id := generateGroupKey(group)
seen[id] = struct{}{}
@@ -1069,7 +1113,10 @@ func (s *DefaultServer) projectNSGroupHealth(snap nsHealthSnapshot) {
enabled = s.projectUnhealthy(p, servers, immediate, now, delay)
case nsVerdictUndecided:
// Stay Available until evidence says otherwise, unless a
// warning is already active for this group.
// warning is already active for this group. Also clear any
// prior Unhealthy streak so a later Unhealthy verdict starts
// a fresh grace window rather than inheriting a stale one.
p.unhealthySince = time.Time{}
enabled = !p.warningActive
groupErr = nil
}
@@ -1142,6 +1189,9 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
// much slower than route count, so linear scaling would overcorrect.
//
// TODO: revisit the scaling curve with real-world data — the current
// values are a reasonable starting point, not a measured fit.
func (s *DefaultServer) warningDelay(routeCount int) time.Duration {
bonus := time.Duration(routeCount/100) * time.Second
if bonus > warningDelayBonusCap {
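For illustration, a standalone Go sketch of the scaling just described; only the +1s-per-100-routes bonus and the 30s cap come from the diff, the base delay constant here is an assumed value for the example:

package main

import (
	"fmt"
	"time"
)

const (
	assumedBaseWarningDelay = 10 * time.Second // assumption for this sketch only
	warningDelayBonusCap    = 30 * time.Second
)

// warningDelaySketch adds +1s per 100 routes on top of a base delay and
// caps the bonus so very large route tables can't postpone warnings forever.
func warningDelaySketch(routeCount int) time.Duration {
	bonus := time.Duration(routeCount/100) * time.Second
	if bonus > warningDelayBonusCap {
		bonus = warningDelayBonusCap
	}
	return assumedBaseWarningDelay + bonus
}

func main() {
	for _, n := range []int{0, 250, 5000} {
		// 0 routes -> 10s, 250 routes -> 12s, 5000 routes -> 40s (bonus capped at 30s)
		fmt.Printf("%d routes -> %v\n", n, warningDelaySketch(n))
	}
}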
@@ -1164,11 +1214,16 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
for _, srv := range servers {
addr := srv.Addr().Unmap()
overlay := overlayV4.IsValid() && overlayV4.Contains(addr)
routed := haMapContains(snap.selected, addr)
selMatched, selDynamic := haMapContains(snap.selected, addr)
// Treat an unknown (dynamic selected route) as possibly routed:
// the upstream might reach through a dynamic route whose Network
// hasn't resolved yet, and classifying as public would bypass
// the startup grace window.
routed := selMatched || selDynamic
if !overlay && !routed {
return true
}
if haMapContains(snap.active, addr) {
if actMatched, _ := haMapContains(snap.active, addr); actMatched {
return true
}
}
@@ -1290,15 +1345,6 @@ func classifyUpstreamHealth(h UpstreamHealth, now time.Time) upstreamClassificat
return upstreamStale
}

// nameServerAddrPorts flattens a NameServer list to AddrPorts.
func nameServerAddrPorts(ns []nbdns.NameServer) []netip.AddrPort {
out := make([]netip.AddrPort, 0, len(ns))
for _, n := range ns {
out = append(out, n.AddrPort())
}
return out
}

func joinAddrPorts(servers []netip.AddrPort) string {
parts := make([]string, 0, len(servers))
for _, s := range servers {
@@ -1307,12 +1353,18 @@ func joinAddrPorts(servers []netip.AddrPort) string {
return strings.Join(parts, ", ")
}

// generateGroupKey returns a stable identity for an NS group so health
// state (everHealthy / warningActive) survives reorderings in the
// configured nameserver or domain lists.
func generateGroupKey(nsGroup *nbdns.NameServerGroup) nsGroupID {
var servers []string
servers := make([]string, 0, len(nsGroup.NameServers))
for _, ns := range nsGroup.NameServers {
servers = append(servers, ns.AddrPort().String())
}
return nsGroupID(fmt.Sprintf("%v_%v", servers, nsGroup.Domains))
slices.Sort(servers)
domains := slices.Clone(nsGroup.Domains)
slices.Sort(domains)
return nsGroupID(fmt.Sprintf("%v_%v", servers, domains))
}
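A standalone sketch of why sorting copies makes the key order-insensitive; the types below are simplified stand-ins, not the diff's nbdns types:

package main

import (
	"fmt"
	"slices"
)

// groupKeySketch builds an identity string from sorted copies of the server
// and domain lists, so reordering the configuration doesn't change the key
// and any health state keyed by it survives reorders.
func groupKeySketch(servers, domains []string) string {
	s := slices.Clone(servers)
	slices.Sort(s)
	d := slices.Clone(domains)
	slices.Sort(d)
	return fmt.Sprintf("%v_%v", s, d)
}

func main() {
	a := groupKeySketch([]string{"1.1.1.1:53", "8.8.8.8:53"}, []string{"example.com"})
	b := groupKeySketch([]string{"8.8.8.8:53", "1.1.1.1:53"}, []string{"example.com"})
	fmt.Println(a == b) // true: same group, different ordering
}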

// groupNSGroupsByDomain groups nameserver groups by their match domains

@@ -1001,7 +1001,6 @@ type mockHandler struct {

func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }

type mockService struct{}

@@ -163,8 +163,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}

// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -196,16 +196,20 @@ func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}

// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := u.flatUpstreams()
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })

hash := sha256.New()
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
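A standalone sketch of the same hashing idea with plain strings, showing how the bracket/pipe delimiters keep the two groupings distinct; the helper name is made up for the example:

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// raceAwareID hashes each race group inside "[...]" brackets with "|"
// separators, so the same servers split into different groups produce
// different IDs ([[A,B]] = serial fallback, [[A],[B]] = parallel race).
func raceAwareID(domain string, groups [][]string) string {
	h := sha256.New()
	h.Write([]byte(domain + ":"))
	for _, race := range groups {
		h.Write([]byte("["))
		for _, s := range race {
			h.Write([]byte(s))
			h.Write([]byte("|"))
		}
		h.Write([]byte("]"))
	}
	return "upstream-" + hex.EncodeToString(h.Sum(nil)[:8])
}

func main() {
	serial := raceAwareID("example.com.", [][]string{{"1.1.1.1:53", "8.8.8.8:53"}})
	parallel := raceAwareID("example.com.", [][]string{{"1.1.1.1:53"}, {"8.8.8.8:53"}})
	fmt.Println(serial == parallel) // false: grouping is part of the identity
}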
@@ -228,13 +232,11 @@ func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
return out
}

// isRouted reports whether ip falls inside any client route the admin
// has selected.
func (u *upstreamResolverBase) isRouted(ip netip.Addr) bool {
if u.selectedRoutes == nil {
return false
}
return haMapContains(u.selectedRoutes(), ip)
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}

func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
@@ -313,6 +315,8 @@ func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
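The fan-out pattern above, reduced to a standalone sketch: the results channel is buffered to the group count so losing workers can still send after the coordinator returns; names and types here are illustrative only:

package main

import (
	"context"
	"fmt"
	"time"
)

type result struct {
	id  int
	err error
}

// raceSketch fans out one worker per group and returns the first success.
// Buffering the channel to len(groups) means losers never block on send,
// so no goroutine leaks after the coordinator has returned.
func raceSketch(ctx context.Context, groups []int, query func(context.Context, int) error) (int, error) {
	raceCtx, cancel := context.WithCancel(ctx)
	defer cancel() // cancel the losers once a winner is found

	results := make(chan result, len(groups))
	for _, g := range groups {
		go func(g int) {
			results <- result{id: g, err: query(raceCtx, g)}
		}(g)
	}

	var lastErr error
	for range groups {
		r := <-results
		if r.err == nil {
			return r.id, nil
		}
		lastErr = r.err
	}
	return 0, lastErr
}

func main() {
	winner, err := raceSketch(context.Background(), []int{1, 2}, func(ctx context.Context, id int) error {
		if id == 1 {
			select {
			case <-time.After(time.Second):
			case <-ctx.Done():
			}
			return fmt.Errorf("upstream %d timed out", id)
		}
		return nil
	})
	fmt.Println(winner, err) // 2 <nil>
}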
@@ -337,7 +341,14 @@ func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
}

var failures []upstreamFailure
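A standalone sketch of the per-upstream budget split; the two constants below are assumed values for the example, the diff does not show their real settings:

package main

import (
	"fmt"
	"time"
)

const (
	raceMaxTotalTimeout       = 4 * time.Second        // assumed for this sketch
	raceMinPerUpstreamTimeout = 500 * time.Millisecond // assumed for this sketch
)

// perUpstreamTimeout divides the total budget across the group but never
// drops below the per-upstream floor; an outer context would still cap
// the whole walk at the total budget.
func perUpstreamTimeout(groupSize int) time.Duration {
	return max(raceMaxTotalTimeout/time.Duration(groupSize), raceMinPerUpstreamTimeout)
}

func main() {
	for _, n := range []int{1, 4, 20} {
		// 1 -> 4s, 4 -> 1s, 20 -> 500ms (floor applies)
		fmt.Printf("group of %d -> %v per upstream\n", n, perUpstreamTimeout(n))
	}
}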
@@ -345,7 +356,11 @@ func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group up
if ctx.Err() != nil {
return raceResult{failures: failures}
}
msg, proto, failure := u.queryUpstream(ctx, r, upstream, timeout)
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
msg, proto, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
@@ -358,12 +373,19 @@ func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group up
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (*dns.Msg, string, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithupstreamProtocolResult(ctx)
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)

startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)

if err != nil {
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return nil, "", &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return nil, "", failure
@@ -522,13 +544,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -548,18 +567,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}

var (
rm *dns.Msg
t time.Duration
err error
)

if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}

rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -573,15 +581,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.

tcpClient := *client
tcpClient.Net = protoTCP

if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}

rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -595,6 +595,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
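A minimal standalone sketch of the UDP-then-TCP-on-truncation pattern using github.com/miekg/dns directly, without the project's MTU capping, protocol context, or dialer handling:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/miekg/dns"
)

// exchangeWithTCPFallbackSketch tries UDP first and retries over TCP only
// when the UDP answer comes back truncated (TC bit set).
func exchangeWithTCPFallbackSketch(ctx context.Context, r *dns.Msg, upstream string) (*dns.Msg, error) {
	udp := &dns.Client{Net: "udp", Timeout: 3 * time.Second}
	rm, _, err := udp.ExchangeContext(ctx, r, upstream)
	if err != nil {
		return nil, fmt.Errorf("with udp: %w", err)
	}
	if !rm.Truncated {
		return rm, nil
	}
	tcp := &dns.Client{Net: "tcp", Timeout: 3 * time.Second}
	rm, _, err = tcp.ExchangeContext(ctx, r, upstream)
	if err != nil {
		return nil, fmt.Errorf("with tcp: %w", err)
	}
	return rm, nil
}

func main() {
	m := new(dns.Msg)
	m.SetQuestion("example.com.", dns.TypeTXT)
	if rm, err := exchangeWithTCPFallbackSketch(context.Background(), m, "1.1.1.1:53"); err == nil {
		fmt.Println(len(rm.Answer), "answers")
	}
}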

// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
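The LocalAddr conversion in isolation, as a standalone sketch using only the standard library:

package main

import (
	"fmt"
	"net"
)

// toTCPLocalAddr mirrors the conversion above: a UDP-bound LocalAddr on a
// net.Dialer would make a TCP dial fail with a mismatched local address
// type, so it is swapped for the equivalent *net.TCPAddr.
func toTCPLocalAddr(d net.Dialer) net.Dialer {
	if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
		d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
	}
	return d
}

func main() {
	d := net.Dialer{LocalAddr: &net.UDPAddr{IP: net.ParseIP("100.64.0.1")}}
	d = toTCPLocalAddr(d)
	fmt.Printf("%T\n", d.LocalAddr) // *net.TCPAddr
}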

// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -736,22 +755,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}

// haMapContains reports whether any route in the map contains ip.
//
// Gap: dynamic (domain-based) routes carry a placeholder Network that
// never matches a real address, so an upstream reached via a dynamic
// route is classified as "not routed" here. The DNS health path then
// emits failure events immediately for such upstreams instead of
// applying the startup grace window. Rare (DNS servers are usually
// designated by IP, not by domain) but worth revisiting if DoT/DoH-style
// upstreams or /etc/hosts-style domain routing to DNS become supported.
func haMapContains(hm route.HAMap, ip netip.Addr) bool {
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
return total
}

// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true
return true, haveDynamic
}
}
}
return false
return false, haveDynamic
}

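A standalone sketch of the matched/haveDynamic contract with a simplified route type; the real route.HAMap and IsDynamic are not used here:

package main

import (
	"fmt"
	"net/netip"
)

// routeSketch is a simplified stand-in for the project's route type:
// dynamic (domain-based) routes have no concrete prefix to check.
type routeSketch struct {
	prefix  netip.Prefix
	dynamic bool
}

// containsSketch reports a concrete prefix match and, separately, whether
// any dynamic route was present; the caller decides how to treat the
// unknown (health projection: possibly routed; iOS dialing: not routed).
func containsSketch(routes []routeSketch, ip netip.Addr) (matched, haveDynamic bool) {
	for _, r := range routes {
		if r.dynamic {
			haveDynamic = true
			continue
		}
		if r.prefix.Contains(ip) {
			return true, haveDynamic
		}
	}
	return false, haveDynamic
}

func main() {
	routes := []routeSketch{
		{dynamic: true},
		{prefix: netip.MustParsePrefix("10.0.0.0/8")},
	}
	fmt.Println(containsSketch(routes, netip.MustParseAddr("192.168.1.53"))) // false true
}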
@@ -66,7 +66,14 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} else {
upstreamIP = upstreamIP.Unmap()
}
needsPrivate := u.lNet.Contains(upstreamIP) || u.isRouted(upstreamIP)
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := u.lNet.Contains(upstreamIP) || routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
@@ -75,8 +82,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}

// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}

// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"

@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}

type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}

// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}

type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}

type mockUpstreamResolverPerServer struct {
@@ -153,11 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}

func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}

func TestUpstreamResolver_Failover(t *testing.T) {
@@ -400,7 +399,10 @@ func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {

mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
broken.String(): {err: timeoutErr},
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
@@ -412,7 +414,7 @@ func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 100 * time.Millisecond,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
@@ -740,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -774,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)

expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}