mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Compare commits
2 Commits
fix/manage
...
mgmt-cache
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e6c62410ea | ||
|
|
7486738d0a |
@@ -1,7 +1,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
c.dispatch(w, r, math.MaxInt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||||
|
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||||
|
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
|
if entry.Priority > maxPriority {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !c.isHandlerMatch(qname, entry) {
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
cw.response.Len(), meta, time.Since(startTime))
|
cw.response.Len(), meta, time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||||
|
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||||
|
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||||
|
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||||
|
// (bounded by the invoked handler's internal timeout).
|
||||||
|
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||||
|
if len(r.Question) == 0 {
|
||||||
|
return nil, fmt.Errorf("empty question")
|
||||||
|
}
|
||||||
|
|
||||||
|
base := &internalResponseWriter{}
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
c.dispatch(base, r, maxPriority)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||||
|
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||||
|
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||||
|
}
|
||||||
|
return base.response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||||
|
// priority ≤ maxPriority.
|
||||||
|
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, h := range c.handlers {
|
||||||
|
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
switch {
|
switch {
|
||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||||
|
type internalResponseWriter struct {
|
||||||
|
response *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||||
|
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
|
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
|
|
||||||
|
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||||
|
// still surface their answer to ResolveInternal.
|
||||||
|
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
if err := msg.Unpack(p); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
w.response = msg
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *internalResponseWriter) Close() error { return nil }
|
||||||
|
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||||
|
|
||||||
|
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||||
|
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||||
|
// no-op: in-process queries carry no TSIG state.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack is part of dns.ResponseWriter.
|
||||||
|
func (w *internalResponseWriter) Hijack() {
|
||||||
|
// no-op: in-process queries have no underlying connection to hand off.
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||||
|
// which handler ResolveInternal dispatches to.
|
||||||
|
type answeringHandler struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Answer = []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP(h.ip).To4(),
|
||||||
|
}}
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *answeringHandler) String() string { return h.name }
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||||
|
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||||
|
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, resp)
|
||||||
|
assert.Equal(t, 1, len(resp.Answer))
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||||
|
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||||
|
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||||
|
type rawWriteHandler struct {
|
||||||
|
ip string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
resp.Answer = []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP(h.ip).To4(),
|
||||||
|
}}
|
||||||
|
packed, err := resp.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = w.Write(packed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Len(t, resp.Answer, 1)
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||||
|
type hangingHandler struct {
|
||||||
|
block chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
<-h.block
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(r)
|
||||||
|
_ = w.WriteMsg(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||||
|
|
||||||
|
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
h := &hangingHandler{block: make(chan struct{})}
|
||||||
|
defer close(h.block)
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||||
|
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||||
|
|
||||||
|
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||||
|
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||||
|
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||||
|
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||||
|
|
||||||
|
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||||
|
|
||||||
|
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||||
|
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||||
|
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||||
|
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||||
|
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,40 +2,79 @@ package mgmt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const dnsTimeout = 5 * time.Second
|
const (
|
||||||
|
dnsTimeout = 5 * time.Second
|
||||||
|
defaultTTL = 300 * time.Second
|
||||||
|
refreshBackoff = 30 * time.Second
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains
|
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||||
|
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||||
|
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||||
|
// system resolver.
|
||||||
|
type ChainResolver interface {
|
||||||
|
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||||
|
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||||
|
// records and cachedAt are set at construction and treated as immutable;
|
||||||
|
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||||
|
// Resolver.mutex.
|
||||||
|
type cachedRecord struct {
|
||||||
|
records []dns.RR
|
||||||
|
cachedAt time.Time
|
||||||
|
lastFailedRefresh time.Time
|
||||||
|
consecFailures int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolver caches critical NetBird infrastructure domains.
|
||||||
|
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question][]dns.RR
|
records map[dns.Question]*cachedRecord
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
|
||||||
|
|
||||||
type ipsResponse struct {
|
chain ChainResolver
|
||||||
ips []netip.Addr
|
chainMaxPriority int
|
||||||
err error
|
refreshGroup singleflight.Group
|
||||||
|
|
||||||
|
// refreshing tracks questions whose refresh is running via the OS
|
||||||
|
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||||
|
// the OS resolver routed the recursive query back to us (loop). Only
|
||||||
|
// the OS path arms this so chain-path refreshes don't produce false
|
||||||
|
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||||
|
// throttle the warning log.
|
||||||
|
refreshing map[dns.Question]*atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question][]dns.RR),
|
records: make(map[dns.Question]*cachedRecord),
|
||||||
|
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,7 +83,19 @@ func (m *Resolver) String() string {
|
|||||||
return "MgmtCacheResolver"
|
return "MgmtCacheResolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS implements dns.Handler interface.
|
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||||
|
// maxPriority caps which handlers may answer refresh queries (typically
|
||||||
|
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||||
|
// mgmt/route/local handlers are skipped).
|
||||||
|
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.chain = chain
|
||||||
|
m.chainMaxPriority = maxPriority
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||||
|
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -60,7 +111,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
records, found := m.records[question]
|
cached, found := m.records[question]
|
||||||
|
inflight := m.refreshing[question]
|
||||||
|
var shouldRefresh bool
|
||||||
|
if found {
|
||||||
|
stale := time.Since(cached.cachedAt) > cacheTTL()
|
||||||
|
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||||
|
shouldRefresh = stale && !inBackoff
|
||||||
|
}
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
@@ -68,12 +126,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||||
|
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||||
|
question.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||||
|
// this question; singleflight would dedup anyway but skipping avoids
|
||||||
|
// a parked goroutine per stale hit under bursty load.
|
||||||
|
if shouldRefresh && inflight == nil {
|
||||||
|
m.scheduleRefresh(question)
|
||||||
|
}
|
||||||
|
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetReply(r)
|
resp.SetReply(r)
|
||||||
resp.Authoritative = false
|
resp.Authoritative = false
|
||||||
resp.RecursionAvailable = true
|
resp.RecursionAvailable = true
|
||||||
|
resp.Answer = append(resp.Answer, cached.records...)
|
||||||
resp.Answer = append(resp.Answer, records...)
|
|
||||||
|
|
||||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||||
|
|
||||||
@@ -98,101 +167,177 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDomain manually adds a domain to cache by resolving it.
|
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||||
|
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||||
|
// entry for that qtype.
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||||
if err != nil {
|
|
||||||
return err
|
if errA != nil && errAAAA != nil {
|
||||||
|
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||||
}
|
}
|
||||||
|
|
||||||
var aRecords, aaaaRecords []dns.RR
|
// Dual NODATA: don't wipe existing entries, let the caller retry.
|
||||||
for _, ip := range ips {
|
if errA == nil && errAAAA == nil && len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||||
if ip.Is4() {
|
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||||
rr := &dns.A{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: dnsName,
|
|
||||||
Rrtype: dns.TypeA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 300,
|
|
||||||
},
|
|
||||||
A: ip.AsSlice(),
|
|
||||||
}
|
|
||||||
aRecords = append(aRecords, rr)
|
|
||||||
} else if ip.Is6() {
|
|
||||||
rr := &dns.AAAA{
|
|
||||||
Hdr: dns.RR_Header{
|
|
||||||
Name: dnsName,
|
|
||||||
Rrtype: dns.TypeAAAA,
|
|
||||||
Class: dns.ClassINET,
|
|
||||||
Ttl: 300,
|
|
||||||
},
|
|
||||||
AAAA: ip.AsSlice(),
|
|
||||||
}
|
|
||||||
aaaaRecords = append(aaaaRecords, rr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if len(aRecords) > 0 {
|
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||||
aQuestion := dns.Question{
|
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||||
Name: dnsName,
|
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
m.records[aQuestion] = aRecords
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(aaaaRecords) > 0 {
|
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||||
aaaaQuestion := dns.Question{
|
|
||||||
Name: dnsName,
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
m.records[aaaaQuestion] = aaaaRecords
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
// untouched on error. Caller holds m.mutex.
|
||||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||||
resultChan := make(chan *ipsResponse, 1)
|
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||||
|
switch {
|
||||||
go func() {
|
case len(records) > 0:
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||||
resultChan <- &ipsResponse{
|
case err == nil:
|
||||||
err: err,
|
delete(m.records, q)
|
||||||
ips: ips,
|
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
var resp *ipsResponse
|
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||||
|
// unique in-flight key; bursty stale hits share its channel.
|
||||||
|
func (m *Resolver) scheduleRefresh(question dns.Question) {
|
||||||
|
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||||
|
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||||
|
return nil, m.refreshQuestion(question)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
// a resolver loop by spotting a query for this same question arriving on us.
|
||||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
func (m *Resolver) refreshQuestion(question dns.Question) error {
|
||||||
case <-ctx.Done():
|
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||||
return nil, ctx.Err()
|
defer cancel()
|
||||||
case resp = <-resultChan:
|
|
||||||
|
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||||
|
if err != nil {
|
||||||
|
m.markRefreshFailed(question)
|
||||||
|
return fmt.Errorf("parse domain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.err != nil {
|
records, err := m.lookupRecords(ctx, d, question)
|
||||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
if err != nil {
|
||||||
|
fails := m.markRefreshFailed(question)
|
||||||
|
logf := log.Warnf
|
||||||
|
if fails > 1 {
|
||||||
|
logf = log.Debugf
|
||||||
}
|
}
|
||||||
return resp.ips, nil
|
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||||
|
if len(records) == 0 {
|
||||||
|
m.mutex.Lock()
|
||||||
|
delete(m.records, question)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
m.mutex.Lock()
|
||||||
|
if _, stillCached := m.records[question]; !stillCached {
|
||||||
|
m.mutex.Unlock()
|
||||||
|
log.Debugf("skipping refresh write for domain=%s type=%s: entry was removed during refresh",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||||
|
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
m.refreshing[question] = &atomic.Bool{}
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
delete(m.refreshing, question)
|
||||||
|
m.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||||
|
// count so callers can downgrade subsequent failure logs to debug.
|
||||||
|
func (m *Resolver) markRefreshFailed(question dns.Question) int {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
c, ok := m.records[question]
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
c.lastFailedRefresh = time.Now()
|
||||||
|
c.consecFailures++
|
||||||
|
return c.consecFailures
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||||
|
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||||
|
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
chain := m.chain
|
||||||
|
maxPriority := m.chainMaxPriority
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||||
|
aRecords, errA = lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||||
|
aaaaRecords, errAAAA = lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||||
|
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||||
|
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||||
|
aRecords, errA = osLookup(ctx, d, dnsName, dns.TypeA)
|
||||||
|
aaaaRecords, errAAAA = osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||||
|
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||||
|
// spot the OS resolver routing the recursive query back to us.
|
||||||
|
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
chain := m.chain
|
||||||
|
maxPriority := m.chainMaxPriority
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||||
|
return lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: drop once every supported OS registers a fallback resolver.
|
||||||
|
m.markRefreshing(q)
|
||||||
|
defer m.clearRefreshing(q)
|
||||||
|
|
||||||
|
return osLookup(ctx, d, q.Name, q.Qtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
@@ -224,19 +369,8 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
aQuestion := dns.Question{
|
delete(m.records, dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET})
|
||||||
Name: dnsName,
|
delete(m.records, dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET})
|
||||||
Qtype: dns.TypeA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
delete(m.records, aQuestion)
|
|
||||||
|
|
||||||
aaaaQuestion := dns.Question{
|
|
||||||
Name: dnsName,
|
|
||||||
Qtype: dns.TypeAAAA,
|
|
||||||
Qclass: dns.ClassINET,
|
|
||||||
}
|
|
||||||
delete(m.records, aaaaQuestion)
|
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -394,3 +528,82 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cacheTTL() time.Duration {
|
||||||
|
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||||
|
// dnsName as owner and cacheTTL() as TTL, so CNAME-backed domains don't cache
|
||||||
|
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||||
|
func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||||
|
msg := &dns.Msg{}
|
||||||
|
msg.SetQuestion(dnsName, qtype)
|
||||||
|
msg.RecursionDesired = true
|
||||||
|
|
||||||
|
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||||
|
}
|
||||||
|
if resp == nil {
|
||||||
|
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||||
|
}
|
||||||
|
if resp.Rcode != dns.RcodeSuccess {
|
||||||
|
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := uint32(cacheTTL().Seconds())
|
||||||
|
var filtered []dns.RR
|
||||||
|
for _, rr := range resp.Answer {
|
||||||
|
if rr.Header().Class != dns.ClassINET {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch r := rr.(type) {
|
||||||
|
case *dns.A:
|
||||||
|
if qtype != dns.TypeA {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, &dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
|
||||||
|
A: append(net.IP(nil), r.A...),
|
||||||
|
})
|
||||||
|
case *dns.AAAA:
|
||||||
|
if qtype != dns.TypeAAAA {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl},
|
||||||
|
AAAA: append(net.IP(nil), r.AAAA...),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||||
|
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||||
|
// returns (nil, nil).
|
||||||
|
func osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||||
|
network := resutil.NetworkForQtype(qtype)
|
||||||
|
if network == "" {
|
||||||
|
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||||
|
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||||
|
|
||||||
|
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||||
|
if result.Rcode == dns.RcodeSuccess {
|
||||||
|
return resutil.IPsToRRs(dnsName, result.IPs, uint32(cacheTTL().Seconds())), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||||
|
}
|
||||||
|
|||||||
359
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
359
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
package mgmt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeChain struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls map[string]int
|
||||||
|
answers map[string][]dns.RR
|
||||||
|
err error
|
||||||
|
hasRoot bool
|
||||||
|
onLookup func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeChain() *fakeChain {
|
||||||
|
return &fakeChain{
|
||||||
|
calls: map[string]int{},
|
||||||
|
answers: map[string][]dns.RR{},
|
||||||
|
hasRoot: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.hasRoot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
q := msg.Question[0]
|
||||||
|
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||||
|
f.calls[key]++
|
||||||
|
answers := f.answers[key]
|
||||||
|
err := f.err
|
||||||
|
onLookup := f.onLookup
|
||||||
|
f.mu.Unlock()
|
||||||
|
|
||||||
|
if onLookup != nil {
|
||||||
|
onLookup()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resp := &dns.Msg{}
|
||||||
|
resp.SetReply(msg)
|
||||||
|
resp.Answer = answers
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
key := name + "|" + dns.TypeToString[qtype]
|
||||||
|
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||||
|
switch qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitFor polls fn every 5ms until it returns true or the duration d has
// elapsed, and fails the test if the condition was never met.
//
// fn is always evaluated one final time after the deadline. This avoids a
// spurious failure when the condition becomes true during the last sleep, and
// guarantees fn runs at least once even when d is shorter than one poll.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
	t.Helper()

	const pollInterval = 5 * time.Millisecond

	for deadline := time.Now().Add(d); time.Now().Before(deadline); time.Sleep(pollInterval) {
		if fn() {
			return
		}
	}

	// Last chance: the state may have changed while we were sleeping.
	if fn() {
		return
	}
	t.Fatalf("condition not met within %s", d)
}
|
||||||
|
|
||||||
|
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||||
|
t.Helper()
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.SetQuestion(name, dns.TypeA)
|
||||||
|
w := &test.MockResponseWriter{}
|
||||||
|
r.ServeDNS(w, msg)
|
||||||
|
return w.GetLastResponse()
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||||
|
t.Helper()
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||||
|
a, ok := resp.Answer[0].(*dns.A)
|
||||||
|
require.True(t, ok, "expected A record")
|
||||||
|
return a.A.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(), // fresh
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||||
|
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||||
|
}
|
||||||
|
|
||||||
|
// First query: serves stale immediately.
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||||
|
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||||
|
})
|
||||||
|
|
||||||
|
// Next query should now return the refreshed IP.
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
|
||||||
|
var inflight atomic.Int32
|
||||||
|
var maxInflight atomic.Int32
|
||||||
|
chain.onLookup = func() {
|
||||||
|
cur := inflight.Add(1)
|
||||||
|
defer inflight.Add(-1)
|
||||||
|
for {
|
||||||
|
prev := maxInflight.Load()
|
||||||
|
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||||
|
}
|
||||||
|
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
waitFor(t, 2*time.Second, func() bool {
|
||||||
|
return inflight.Load() == 0
|
||||||
|
})
|
||||||
|
|
||||||
|
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||||
|
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||||
|
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.err = errors.New("boom")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||||
|
}
|
||||||
|
|
||||||
|
// First stale hit triggers a refresh attempt that fails.
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||||
|
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||||
|
})
|
||||||
|
waitFor(t, time.Second, func() bool {
|
||||||
|
r.mutex.RLock()
|
||||||
|
defer r.mutex.RUnlock()
|
||||||
|
c, ok := r.records[q]
|
||||||
|
return ok && !c.lastFailedRefresh.IsZero()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.hasRoot = false
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
// With hasRoot=false the chain must not be consulted. Use a short
|
||||||
|
// deadline so the OS fallback returns quickly without waiting on a
|
||||||
|
// real network call in CI.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||||
|
|
||||||
|
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||||
|
"chain must not be used when no root handler is registered at the bound priority")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||||
|
// ServeDNS being invoked for a question while a refresh for that question
|
||||||
|
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||||
|
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||||
|
r := NewResolver()
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate an inflight refresh.
|
||||||
|
r.markRefreshing(q)
|
||||||
|
defer r.clearRefreshing(q)
|
||||||
|
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
inflight := r.refreshing[q]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
require.NotNil(t, inflight)
|
||||||
|
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
r.markRefreshing(q)
|
||||||
|
defer r.clearRefreshing(q)
|
||||||
|
|
||||||
|
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||||
|
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||||
|
for range 5 {
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
inflight := r.refreshing[q]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.True(t, inflight.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
|
||||||
|
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||||
|
r.records[q] = &cachedRecord{
|
||||||
|
records: []dns.RR{&dns.A{
|
||||||
|
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||||
|
A: net.ParseIP("10.0.0.1").To4(),
|
||||||
|
}},
|
||||||
|
cachedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
queryA(t, r, "mgmt.example.com.")
|
||||||
|
|
||||||
|
r.mutex.RLock()
|
||||||
|
_, ok := r.refreshing[q]
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||||
|
r := NewResolver()
|
||||||
|
chain := newFakeChain()
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||||
|
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||||
|
r.SetChainResolver(chain, 50)
|
||||||
|
|
||||||
|
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||||
|
|
||||||
|
resp := queryA(t, r, "mgmt.example.com.")
|
||||||
|
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||||
|
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||||
|
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||||
|
}
|
||||||
@@ -212,6 +212,7 @@ func newDefaultServer(
|
|||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
mgmtCacheResolver := mgmt.NewResolver()
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
|
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
Reference in New Issue
Block a user