[client] Support dns upstream failover for nameserver groups with same match domain (#3178)

This commit is contained in:
Viktor Liu
2025-02-10 18:13:34 +01:00
committed by GitHub
parent 5953b43ead
commit 488b697479
11 changed files with 747 additions and 186 deletions

View File

@@ -2,9 +2,13 @@ package dns
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -40,6 +44,7 @@ type upstreamResolverBase struct {
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []string
domain string
disabled bool
failsCount atomic.Int32
successCount atomic.Int32
@@ -53,12 +58,13 @@ type upstreamResolverBase struct {
statusRecorder *peer.Status
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
@@ -71,6 +77,17 @@ func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers)
}
// ID returns the unique handler ID
func (u *upstreamResolverBase) id() handlerID {
servers := slices.Clone(u.upstreamServers)
slices.Sort(servers)
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
hash.Write([]byte(strings.Join(servers, ",")))
return handlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
func (u *upstreamResolverBase) MatchSubdomains() bool {
return true
}
@@ -87,7 +104,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
u.checkUpstreamFails(err)
}()
log.WithField("question", r.Question[0]).Trace("received an upstream question")
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
@@ -96,6 +113,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
select {
case <-u.ctx.Done():
log.Tracef("%s has been stopped", u)
return
default:
}
@@ -112,41 +130,36 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.WithError(err).WithField("upstream", upstream).
Warn("got an error while connecting to upstream")
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue
}
u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream).
Error("got other error while querying the upstream")
return
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue
}
if rm == nil {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return
}
// those checks need to be independent of each other due to memory address issues
if !rm.Response {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
return
if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue
}
u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s", t, upstream)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
err = w.WriteMsg(rm)
if err != nil {
log.WithError(err).Error("got an error while writing the upstream resolver response")
if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
}
// count the fails only if they happen sequentially
u.failsCount.Store(0)
return
}
u.failsCount.Add(1)
log.Error("all queries to the upstream nameservers failed with timeout")
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
}
}
// checkUpstreamFails counts fails and disables or enables upstream resolving