[client] Don't deactivate upstream resolvers on failure (#4128)

This commit is contained in:
Viktor Liu
2025-08-29 17:40:05 +02:00
committed by GitHub
parent dbefa8bd9f
commit d4c067f0af
19 changed files with 1598 additions and 167 deletions

View File

@@ -33,9 +33,11 @@ func SetCurrentMTU(mtu uint16) {
}
const (
UpstreamTimeout = 15 * time.Second
UpstreamTimeout = 4 * time.Second
// ClientTimeout is the timeout for the dns.Client.
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
failsTillDeact = int32(5)
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
)
@@ -58,9 +60,7 @@ type upstreamResolverBase struct {
upstreamServers []netip.AddrPort
domain string
disabled bool
failsCount atomic.Int32
successCount atomic.Int32
failsTillDeact int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
@@ -79,14 +79,13 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %s", u.upstreamServers)
return fmt.Sprintf("Upstream %s", u.upstreamServers)
}
// ID returns the unique handler ID
@@ -116,58 +115,102 @@ func (u *upstreamResolverBase) Stop() {
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error
defer func() {
u.checkUpstreamFails(err)
}()
logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
u.prepareRequest(r)
if u.ctx.Err() != nil {
logger.Tracef("%s has been stopped", u)
return
}
if u.tryUpstreamServers(w, r, logger) {
return
}
u.writeErrorResponse(w, r, logger)
}
func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
}
select {
case <-u.ctx.Done():
logger.Tracef("%s has been stopped", u)
return
default:
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
}
for _, upstream := range u.upstreamServers {
var rm *dns.Msg
var t time.Duration
func() {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
defer cancel()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
if u.queryUpstream(w, r, upstream, timeout, logger) {
return true
}
}
return false
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
}
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool {
var rm *dns.Msg
var t time.Duration
var err error
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
var startTime time.Time
func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
defer cancel()
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err = w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
}
// count the fails only if they happen sequentially
u.failsCount.Store(0)
if err != nil {
u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger)
return false
}
if rm == nil || !rm.Response {
logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
return false
}
return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err)
return
}
u.failsCount.Add(1)
elapsed := time.Since(startTime)
timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout)
if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" {
timeoutMsg += " " + peerInfo
}
timeoutMsg += fmt.Sprintf(" - error: %v", err)
logger.Warnf(timeoutMsg)
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
u.successCount.Add(1)
logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, domain)
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
}
return true
}
func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) {
logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
@@ -177,41 +220,6 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}
// checkUpstreamFails counts fails and disables or enables upstream resolving
//
// If fails count is greater that failsTillDeact, upstream resolving
// will be disabled for reactivatePeriod, after that time period fails counter
// will be reset and upstream will be reactivated.
func (u *upstreamResolverBase) checkUpstreamFails(err error) {
u.mutex.Lock()
defer u.mutex.Unlock()
if u.failsCount.Load() < u.failsTillDeact || u.disabled {
return
}
select {
case <-u.ctx.Done():
return
default:
}
u.disable(err)
if u.statusRecorder == nil {
return
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (fail count exceeded)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
// TODO add domain meta
)
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
@@ -224,8 +232,8 @@ func (u *upstreamResolverBase) ProbeAvailability() {
default:
}
// avoid probe if upstreams could resolve at least one query and fails count is less than failsTillDeact
if u.successCount.Load() > 0 && u.failsCount.Load() < u.failsTillDeact {
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
}
@@ -312,7 +320,6 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.failsCount.Store(0)
u.successCount.Add(1)
u.reactivate()
u.disabled = false
@@ -416,3 +423,80 @@ func GenerateRequestID() string {
}
return hex.EncodeToString(bytes)
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
hasRecentHandshake := !peerState.LastWireguardHandshake.IsZero() &&
time.Since(peerState.LastWireguardHandshake) < 3*time.Minute
statusInfo := fmt.Sprintf("%s:%s", peerState.FQDN, peerState.IP)
switch {
case !isConnected:
statusInfo += " DISCONNECTED"
case !hasRecentHandshake:
statusInfo += " NO_RECENT_HANDSHAKE"
default:
statusInfo += " connected"
}
if !peerState.LastWireguardHandshake.IsZero() {
timeSinceHandshake := time.Since(peerState.LastWireguardHandshake)
statusInfo += fmt.Sprintf(" last_handshake=%v_ago", timeSinceHandshake.Truncate(time.Second))
} else {
statusInfo += " no_handshake"
}
if peerState.Relayed {
statusInfo += " via_relay"
}
if peerState.Latency > 0 {
statusInfo += fmt.Sprintf(" latency=%v", peerState.Latency)
}
return statusInfo
}
// findPeerForIP finds which peer handles the given IP address
func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
if statusRecorder == nil {
return nil
}
fullStatus := statusRecorder.GetFullStatus()
var bestMatch *peer.State
var bestPrefixLen int
for _, peerState := range fullStatus.Peers {
routes := peerState.GetRoutes()
for route := range routes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
continue
}
if prefix.Contains(ip) && prefix.Bits() > bestPrefixLen {
peerStateCopy := peerState
bestMatch = &peerStateCopy
bestPrefixLen = prefix.Bits()
}
}
}
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}