Strip OPT from response when client lacked EDNS0; add tests; show EDE in capture

This commit is contained in:
Viktor Liu
2026-05-06 10:31:03 +02:00
parent e394818b9f
commit bea54ef6aa
3 changed files with 145 additions and 4 deletions

View File

@@ -267,8 +267,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
// Advertise EDNS0 so the upstream may include Extended DNS Errors // Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit // (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures. // failover for definitive answers like DNSSEC validation failures.
if r.IsEdns0() == nil { // Operate on a copy so the inbound request is unchanged: a client that
r.SetEdns0(currentMTU-ipUDPHeaderSize, false) // did not advertise EDNS0 must not see an OPT in the response.
hadEdns := r.IsEdns0() != nil
reqUp := r
if !hadEdns {
reqUp = r.Copy()
reqUp.SetEdns0(upstreamUDPSize(), false)
} }
var startTime time.Time var startTime time.Time
@@ -278,7 +283,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
defer cancel() defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx) ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now() startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r) rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
}() }()
if err != nil { if err != nil {
@@ -292,16 +297,47 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok { if code, ok := nonRetryableEDE(rm); ok {
resutil.SetMeta(w, "ede", edeName(code)) resutil.SetMeta(w, "ede", edeName(code))
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil return nil
} }
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]} return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
} }
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger) u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil return nil
} }
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
// derived from the tunnel MTU and bounded against underflow.
func upstreamUDPSize() uint16 {
if currentMTU > ipUDPHeaderSize {
return currentMTU - ipUDPHeaderSize
}
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()} return &upstreamFailure{upstream: upstream, reason: err.Error()}

View File

@@ -770,3 +770,81 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records") assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client") assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
} }
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if len(codes) == 0 {
return m
}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.SetUDPSize(dns.MinMsgSize)
for _, c := range codes {
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
}
m.Extra = append(m.Extra, opt)
return m
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/miekg/dns"
) )
// TextWriter writes human-readable one-line-per-packet summaries. // TextWriter writes human-readable one-line-per-packet summaries.
@@ -594,7 +595,8 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
arCount := d.ARCount arCount := d.ARCount
if d.ResponseCode != layers.DNSResponseCodeNoErr { if d.ResponseCode != layers.DNSResponseCodeNoErr {
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen) ede := formatEDE(d)
return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen)
} }
if anCount > 0 && len(d.Answers) > 0 { if anCount > 0 && len(d.Answers) > 0 {
@@ -607,6 +609,31 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen) return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
} }
// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914).
const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE)
// formatEDE returns " EDE=Name" for the first Extended DNS Error option
// found in the response, or empty string if none is present.
func formatEDE(d *layers.DNS) string {
for _, rr := range d.Additionals {
if rr.Type != layers.DNSTypeOPT {
continue
}
for _, opt := range rr.OPT {
if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 {
continue
}
info := binary.BigEndian.Uint16(opt.Data[:2])
name, ok := dns.ExtendedErrorCodeToString[info]
if !ok {
name = fmt.Sprintf("%d", info)
}
return " EDE=" + name
}
}
return ""
}
func shortRData(rr *layers.DNSResourceRecord) string { func shortRData(rr *layers.DNSResourceRecord) string {
switch rr.Type { switch rr.Type {
case layers.DNSTypeA, layers.DNSTypeAAAA: case layers.DNSTypeA, layers.DNSTypeAAAA: