diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 02934ab51..0b6dd4389 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -267,8 +267,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re // Advertise EDNS0 so the upstream may include Extended DNS Errors // (RFC 8914) in failure responses; we use those to short-circuit // failover for definitive answers like DNSSEC validation failures. - if r.IsEdns0() == nil { - r.SetEdns0(currentMTU-ipUDPHeaderSize, false) + // Operate on a copy so the inbound request is unchanged: a client that + // did not advertise EDNS0 must not see an OPT in the response. + hadEdns := r.IsEdns0() != nil + reqUp := r + if !hadEdns { + reqUp = r.Copy() + reqUp.SetEdns0(upstreamUDPSize(), false) } var startTime time.Time @@ -278,7 +283,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re defer cancel() ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) startTime = time.Now() - rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) + rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp) }() if err != nil { @@ -292,16 +297,47 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if code, ok := nonRetryableEDE(rm); ok { resutil.SetMeta(w, "ede", edeName(code)) + if !hadEdns { + stripOPT(rm) + } u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) return nil } return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} } + if !hadEdns { + stripOPT(rm) + } u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) return nil } +// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams, +// derived from the tunnel MTU and bounded against underflow. +func upstreamUDPSize() uint16 { + if currentMTU > ipUDPHeaderSize { + return currentMTU - ipUDPHeaderSize + } + return dns.MinMsgSize +} + +// stripOPT removes any OPT pseudo-RRs from the response's Extra section so +// the response complies with RFC 6891 when the client did not advertise EDNS0. +func stripOPT(rm *dns.Msg) { + if len(rm.Extra) == 0 { + return + } + out := rm.Extra[:0] + for _, rr := range rm.Extra { + if _, ok := rr.(*dns.OPT); ok { + continue + } + out = append(out, rr) + } + rm.Extra = out +} + func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { return &upstreamFailure{upstream: upstream, reason: err.Error()} diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 1797fdad8..dad55eb8d 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -770,3 +770,81 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) { assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") } + +func msgWithEDE(rcode int, codes ...uint16) *dns.Msg { + m := new(dns.Msg) + m.Response = true + m.Rcode = rcode + if len(codes) == 0 { + return m + } + opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + opt.SetUDPSize(dns.MinMsgSize) + for _, c := range codes { + opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c}) + } + m.Extra = append(m.Extra, opt) + return m +} + +func TestNonRetryableEDE(t *testing.T) { + tests := []struct { + name string + msg *dns.Msg + wantOK bool + wantCode uint16 + }{ + {name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)}, + { + name: "opt without ede", + msg: func() *dns.Msg { + m := msgWithEDE(dns.RcodeServerFailure) + opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}} + opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID}) + m.Extra = []dns.RR{opt} + return m + }(), + }, + {name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus}, + {name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired}, + {name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked}, + {name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited}, + {name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)}, + {name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)}, + {name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)}, + { + name: "first non-retryable wins", + msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus), + wantOK: true, + wantCode: dns.ExtendedErrorCodeDNSBogus, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + code, ok := nonRetryableEDE(tc.msg) + assert.Equal(t, tc.wantOK, ok, "ok should match") + if tc.wantOK { + assert.Equal(t, tc.wantCode, code, "code should match") + } + }) + } +} + +func TestEDEName(t *testing.T) { + assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus)) + assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired)) + assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric") +} + +func TestStripOPT(t *testing.T) { + rm := &dns.Msg{ + Extra: []dns.RR{ + &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}, + &dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)}, + }, + } + stripOPT(rm) + assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept") + _, isOPT := rm.Extra[0].(*dns.OPT) + assert.False(t, isOPT, "remaining record must not be OPT") +} diff --git a/util/capture/text.go b/util/capture/text.go index b44bd0cad..5b8a9357a 100644 --- a/util/capture/text.go +++ b/util/capture/text.go @@ -10,6 +10,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/miekg/dns" ) // TextWriter writes human-readable one-line-per-packet summaries. @@ -594,7 +595,8 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string { arCount := d.ARCount if d.ResponseCode != layers.DNSResponseCodeNoErr { - return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) + ede := formatEDE(d) + return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen) } if anCount > 0 && len(d.Answers) > 0 { @@ -607,6 +609,31 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string { return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) } +// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914). +const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE) + +// formatEDE returns " EDE=Name" for the first Extended DNS Error option +// found in the response, or empty string if none is present. +func formatEDE(d *layers.DNS) string { + for _, rr := range d.Additionals { + if rr.Type != layers.DNSTypeOPT { + continue + } + for _, opt := range rr.OPT { + if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 { + continue + } + info := binary.BigEndian.Uint16(opt.Data[:2]) + name, ok := dns.ExtendedErrorCodeToString[info] + if !ok { + name = fmt.Sprintf("%d", info) + } + return " EDE=" + name + } + } + return "" +} + func shortRData(rr *layers.DNSResourceRecord) string { switch rr.Type { case layers.DNSTypeA, layers.DNSTypeAAAA: