[client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632)

This commit is contained in:
Viktor Liu
2025-04-08 13:41:13 +02:00
committed by GitHub
parent 192c97aa63
commit 03f600b576
5 changed files with 53 additions and 9 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
)
@@ -107,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
@@ -337,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
_, _, err := u.upstreamClient.exchange(ctx, server, r)
return err
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
client.UDPSize = iface.DefaultMTU - (60 + 8)
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
if rm == nil || !rm.MsgHdr.Truncated {
return rm, t, nil
}
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
client.Net = "tcp"
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
return rm, t, nil
}