mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-15 05:09:55 +00:00
[client] Skip DNS upstream failover on definitive EDE (#6089)
This commit is contained in:
@@ -30,6 +30,27 @@ import (
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
|
||||
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
|
||||
// from one upstream means another upstream would return the same answer:
|
||||
// DNSSEC validation outcomes and policy-based blocks. Transient errors
|
||||
// (network, cached, not ready) are not included.
|
||||
var nonRetryableEDECodes = map[uint16]struct{}{
|
||||
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
|
||||
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
|
||||
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
|
||||
dns.ExtendedErrorCodeDNSBogus: {},
|
||||
dns.ExtendedErrorCodeSignatureExpired: {},
|
||||
dns.ExtendedErrorCodeSignatureNotYetValid: {},
|
||||
dns.ExtendedErrorCodeDNSKEYMissing: {},
|
||||
dns.ExtendedErrorCodeRRSIGsMissing: {},
|
||||
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
|
||||
dns.ExtendedErrorCodeNSECMissing: {},
|
||||
dns.ExtendedErrorCodeBlocked: {},
|
||||
dns.ExtendedErrorCodeCensored: {},
|
||||
dns.ExtendedErrorCodeFiltered: {},
|
||||
dns.ExtendedErrorCodeProhibited: {},
|
||||
}
|
||||
|
||||
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
|
||||
type privateClientIface interface {
|
||||
Name() string
|
||||
@@ -250,6 +271,18 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
var t time.Duration
|
||||
var err error
|
||||
|
||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
||||
// failover for definitive answers like DNSSEC validation failures.
|
||||
// Operate on a copy so the inbound request is unchanged: a client that
|
||||
// did not advertise EDNS0 must not see an OPT in the response.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
reqUp := r
|
||||
if !hadEdns {
|
||||
reqUp = r.Copy()
|
||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
@@ -257,7 +290,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
@@ -269,13 +302,49 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
resutil.SetMeta(w, "ede", edeName(code))
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
}
|
||||
|
||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
||||
// derived from the tunnel MTU and bounded against underflow.
|
||||
func upstreamUDPSize() uint16 {
|
||||
if currentMTU > ipUDPHeaderSize {
|
||||
return currentMTU - ipUDPHeaderSize
|
||||
}
|
||||
return dns.MinMsgSize
|
||||
}
|
||||
|
||||
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
|
||||
// the response complies with RFC 6891 when the client did not advertise EDNS0.
|
||||
func stripOPT(rm *dns.Msg) {
|
||||
if len(rm.Extra) == 0 {
|
||||
return
|
||||
}
|
||||
out := rm.Extra[:0]
|
||||
for _, rr := range rm.Extra {
|
||||
if _, ok := rr.(*dns.OPT); ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, rr)
|
||||
}
|
||||
rm.Extra = out
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
|
||||
return &upstreamFailure{upstream: upstream, reason: err.Error()}
|
||||
@@ -337,6 +406,34 @@ func formatFailures(failures []upstreamFailure) string {
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// nonRetryableEDE returns the first non-retryable EDE code carried in the
|
||||
// response, if any.
|
||||
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
|
||||
opt := rm.IsEdns0()
|
||||
if opt == nil {
|
||||
return 0, false
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
ede, ok := o.(*dns.EDNS0_EDE)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
|
||||
return ede.InfoCode, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// edeName returns a human-readable name for an EDE code, falling back to
|
||||
// the numeric code when unknown.
|
||||
func edeName(code uint16) string {
|
||||
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("EDE %d", code)
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
|
||||
@@ -770,3 +770,132 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
|
||||
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
|
||||
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
|
||||
}
|
||||
|
||||
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.Response = true
|
||||
m.Rcode = rcode
|
||||
if len(codes) == 0 {
|
||||
return m
|
||||
}
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.SetUDPSize(dns.MinMsgSize)
|
||||
for _, c := range codes {
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
|
||||
}
|
||||
m.Extra = append(m.Extra, opt)
|
||||
return m
|
||||
}
|
||||
|
||||
func TestNonRetryableEDE(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantOK bool
|
||||
wantCode uint16
|
||||
}{
|
||||
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
|
||||
{
|
||||
name: "opt without ede",
|
||||
msg: func() *dns.Msg {
|
||||
m := msgWithEDE(dns.RcodeServerFailure)
|
||||
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
|
||||
m.Extra = []dns.RR{opt}
|
||||
return m
|
||||
}(),
|
||||
},
|
||||
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
|
||||
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
|
||||
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
|
||||
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
|
||||
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
|
||||
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
|
||||
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
|
||||
{
|
||||
name: "first non-retryable wins",
|
||||
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
|
||||
wantOK: true,
|
||||
wantCode: dns.ExtendedErrorCodeDNSBogus,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code, ok := nonRetryableEDE(tc.msg)
|
||||
assert.Equal(t, tc.wantOK, ok, "ok should match")
|
||||
if tc.wantOK {
|
||||
assert.Equal(t, tc.wantCode, code, "code should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEDEName(t *testing.T) {
|
||||
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
|
||||
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
|
||||
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
|
||||
}
|
||||
|
||||
func TestStripOPT(t *testing.T) {
|
||||
rm := &dns.Msg{
|
||||
Extra: []dns.RR{
|
||||
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
|
||||
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
|
||||
},
|
||||
}
|
||||
stripOPT(rm)
|
||||
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
|
||||
_, isOPT := rm.Extra[0].(*dns.OPT)
|
||||
assert.False(t, isOPT, "remaining record must not be OPT")
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
|
||||
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
|
||||
|
||||
var queried []string
|
||||
tracking := &trackingMockClient{
|
||||
inner: &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
upstream1.String(): {msg: servfailWithEDE},
|
||||
upstream2.String(): {msg: successResp},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
},
|
||||
queriedUpstreams: &queried,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: tracking,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
var written *dns.Msg
|
||||
w := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
written = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Client query without EDNS0 must not see an OPT in the response.
|
||||
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
resolver.ServeDNS(w, q)
|
||||
|
||||
require.NotNil(t, written, "response must be written")
|
||||
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
|
||||
assert.Len(t, queried, 1, "only first upstream should be queried")
|
||||
assert.Equal(t, upstream1.String(), queried[0])
|
||||
for _, rr := range written.Extra {
|
||||
_, isOPT := rr.(*dns.OPT)
|
||||
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user