From 269d5d1cbab7f67b289b885a3ed9d3493631ef16 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 23 Jan 2026 18:59:52 +0800 Subject: [PATCH] [client] Try next DNS upstream on SERVFAIL/REFUSED responses (#5163) --- client/internal/dns/local/local.go | 5 +- client/internal/dns/upstream.go | 84 +++++--- client/internal/dns/upstream_test.go | 284 +++++++++++++++++++++++++++ 3 files changed, 346 insertions(+), 27 deletions(-) diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index cbdc64997..b374bcc6a 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -81,7 +81,10 @@ func (d *Resolver) ProbeAvailability() {} // ServeDNS handles a DNS request func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - logger := log.WithField("request_id", resutil.GetRequestID(w)) + logger := log.WithFields(log.Fields{ + "request_id": resutil.GetRequestID(w), + "dns_id": fmt.Sprintf("%04x", r.Id), + }) if len(r.Question) == 0 { logger.Debug("received local resolver request with no question") diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 654d280ef..0fbd32771 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -71,6 +71,11 @@ type upstreamResolverBase struct { statusRecorder *peer.Status } +type upstreamFailure struct { + upstream netip.AddrPort + reason string +} + func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase { ctx, cancel := context.WithCancel(ctx) @@ -114,7 +119,10 @@ func (u *upstreamResolverBase) Stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - logger := log.WithField("request_id", resutil.GetRequestID(w)) + logger := log.WithFields(log.Fields{ + "request_id": resutil.GetRequestID(w), + "dns_id": fmt.Sprintf("%04x", r.Id), + }) u.prepareRequest(r) @@ -123,11 +131,13 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - if u.tryUpstreamServers(w, r, logger) { - return + ok, failures := u.tryUpstreamServers(w, r, logger) + if len(failures) > 0 { + u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger) + } + if !ok { + u.writeErrorResponse(w, r, logger) } - - u.writeErrorResponse(w, r, logger) } func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { @@ -136,7 +146,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) { } } -func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) bool { +func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) { timeout := u.upstreamTimeout if len(u.upstreamServers) > 1 { maxTotal := 5 * time.Second @@ -149,15 +159,19 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M } } + var failures []upstreamFailure for _, upstream := range u.upstreamServers { - if u.queryUpstream(w, r, upstream, timeout, logger) { - return true + if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil { + failures = append(failures, *failure) + } else { + return true, failures } } - return false + return false, failures } -func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) bool { +// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream. +func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure { var rm *dns.Msg var t time.Duration var err error @@ -171,31 +185,32 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u }() if err != nil { - u.handleUpstreamError(err, upstream, r.Question[0].Name, startTime, timeout, logger) - return false + return u.handleUpstreamError(err, upstream, startTime) } if rm == nil || !rm.Response { - logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) - return false + return &upstreamFailure{upstream: upstream, reason: "no response"} } - return u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { + return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} + } + + u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger) + return nil } -func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, domain string, startTime time.Time, timeout time.Duration, logger *log.Entry) { +func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { - logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, domain, err) - return + return &upstreamFailure{upstream: upstream, reason: err.Error()} } elapsed := time.Since(startTime) - timeoutMsg := fmt.Sprintf("upstream %s timed out for question domain=%s after %v (timeout=%v)", upstream, domain, elapsed.Truncate(time.Millisecond), timeout) + reason := fmt.Sprintf("timeout after %v", elapsed.Truncate(time.Millisecond)) if peerInfo := u.debugUpstreamTimeout(upstream); peerInfo != "" { - timeoutMsg += " " + peerInfo + reason += " " + peerInfo } - timeoutMsg += fmt.Sprintf(" - error: %v", err) - logger.Warn(timeoutMsg) + return &upstreamFailure{upstream: upstream, reason: reason} } func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool { @@ -215,16 +230,34 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn return true } -func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { - logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) +func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) { + totalUpstreams := len(u.upstreamServers) + failedCount := len(failures) + failureSummary := formatFailures(failures) + if succeeded { + logger.Warnf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary) + } else { + logger.Errorf("%d/%d upstreams failed for domain=%s: %s", failedCount, totalUpstreams, domain, failureSummary) + } +} + +func (u *upstreamResolverBase) writeErrorResponse(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) { m := new(dns.Msg) m.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(m); err != nil { - logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + logger.Errorf("write error response for domain=%s: %s", r.Question[0].Name, err) } } +func formatFailures(failures []upstreamFailure) string { + parts := make([]string, 0, len(failures)) + for _, f := range failures { + parts = append(parts, fmt.Sprintf("%s=%s", f.upstream, f.reason)) + } + return strings.Join(parts, ", ") +} + // ProbeAvailability tests all upstream servers simultaneously and // disables the resolver if none work func (u *upstreamResolverBase) ProbeAvailability() { @@ -468,7 +501,6 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst return reply, nil } - // FormatPeerStatus formats peer connection status information for debugging DNS timeouts func FormatPeerStatus(peerState *peer.State) string { isConnected := peerState.ConnStatus == peer.StatusConnected diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 2852f4775..8b06e4475 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -2,6 +2,7 @@ package dns import ( "context" + "fmt" "net" "net/netip" "strings" @@ -9,6 +10,8 @@ import ( "time" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface/device" @@ -140,6 +143,23 @@ func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) return c.r, c.rtt, c.err } +type mockUpstreamResponse struct { + msg *dns.Msg + err error +} + +type mockUpstreamResolverPerServer struct { + responses map[string]mockUpstreamResponse + rtt time.Duration +} + +func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { + if r, ok := c.responses[upstream]; ok { + return r.msg, c.rtt, r.err + } + return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream) +} + func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { mockClient := &mockUpstreamResolver{ err: dns.ErrTime, @@ -191,3 +211,267 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) { t.Errorf("should be enabled") } } + +func TestUpstreamResolver_Failover(t *testing.T) { + upstream1 := netip.MustParseAddrPort("192.0.2.1:53") + upstream2 := netip.MustParseAddrPort("192.0.2.2:53") + + successAnswer := "192.0.2.100" + timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")} + + testCases := []struct { + name string + upstream1 mockUpstreamResponse + upstream2 mockUpstreamResponse + expectedRcode int + expectAnswer bool + expectTrySecond bool + }{ + { + name: "success on first upstream", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: false, + }, + { + name: "SERVFAIL from first should try second", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "REFUSED from first should try second", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "NXDOMAIN from first should NOT try second", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeNameError, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeNameError, + expectAnswer: false, + expectTrySecond: false, + }, + { + name: "timeout from first should try second", + upstream1: mockUpstreamResponse{err: timeoutErr}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "no response from first should try second", + upstream1: mockUpstreamResponse{msg: nil}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeSuccess, successAnswer)}, + expectedRcode: dns.RcodeSuccess, + expectAnswer: true, + expectTrySecond: true, + }, + { + name: "both upstreams return SERVFAIL", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "both upstreams timeout", + upstream1: mockUpstreamResponse{err: timeoutErr}, + upstream2: mockUpstreamResponse{err: timeoutErr}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "first SERVFAIL then timeout", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + upstream2: mockUpstreamResponse{err: timeoutErr}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "first timeout then SERVFAIL", + upstream1: mockUpstreamResponse{err: timeoutErr}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + { + name: "first REFUSED then SERVFAIL", + upstream1: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")}, + upstream2: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")}, + expectedRcode: dns.RcodeServerFailure, + expectAnswer: false, + expectTrySecond: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var queriedUpstreams []string + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + upstream1.String(): tc.upstream1, + upstream2.String(): tc.upstream2, + }, + rtt: time.Millisecond, + } + + trackingClient := &trackingMockClient{ + inner: mockClient, + queriedUpstreams: &queriedUpstreams, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: trackingClient, + upstreamServers: []netip.AddrPort{upstream1, upstream2}, + upstreamTimeout: UpstreamTimeout, + } + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + resolver.ServeDNS(responseWriter, inputMSG) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, tc.expectedRcode, responseMSG.Rcode, "unexpected rcode") + + if tc.expectAnswer { + require.NotEmpty(t, responseMSG.Answer, "expected answer records") + assert.Contains(t, responseMSG.Answer[0].String(), successAnswer) + } + + if tc.expectTrySecond { + assert.Len(t, queriedUpstreams, 2, "should have tried both upstreams") + assert.Equal(t, upstream1.String(), queriedUpstreams[0]) + assert.Equal(t, upstream2.String(), queriedUpstreams[1]) + } else { + assert.Len(t, queriedUpstreams, 1, "should have only tried first upstream") + assert.Equal(t, upstream1.String(), queriedUpstreams[0]) + } + }) + } +} + +type trackingMockClient struct { + inner *mockUpstreamResolverPerServer + queriedUpstreams *[]string +} + +func (t *trackingMockClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) { + *t.queriedUpstreams = append(*t.queriedUpstreams, upstream) + return t.inner.exchange(ctx, upstream, r) +} + +func buildMockResponse(rcode int, answer string) *dns.Msg { + m := new(dns.Msg) + m.Response = true + m.Rcode = rcode + + if rcode == dns.RcodeSuccess && answer != "" { + m.Answer = []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP(answer), + }, + } + } + return m +} + +func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) { + upstream := netip.MustParseAddrPort("192.0.2.1:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + upstream.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")}, + }, + rtt: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resolver := &upstreamResolverBase{ + ctx: ctx, + upstreamClient: mockClient, + upstreamServers: []netip.AddrPort{upstream}, + upstreamTimeout: UpstreamTimeout, + } + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA) + resolver.ServeDNS(responseWriter, inputMSG) + + require.NotNil(t, responseMSG, "should write a response") + assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL") +} + +func TestFormatFailures(t *testing.T) { + testCases := []struct { + name string + failures []upstreamFailure + expected string + }{ + { + name: "empty slice", + failures: []upstreamFailure{}, + expected: "", + }, + { + name: "single failure", + failures: []upstreamFailure{ + {upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"}, + }, + expected: "8.8.8.8:53=SERVFAIL", + }, + { + name: "multiple failures", + failures: []upstreamFailure{ + {upstream: netip.MustParseAddrPort("8.8.8.8:53"), reason: "SERVFAIL"}, + {upstream: netip.MustParseAddrPort("8.8.4.4:53"), reason: "timeout after 2s"}, + }, + expected: "8.8.8.8:53=SERVFAIL, 8.8.4.4:53=timeout after 2s", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := formatFailures(tc.failures) + assert.Equal(t, tc.expected, result) + }) + } +}