mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Implement DNS query caching in DNSForwarder (#4574)
implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails. - Added a new cache component to store DNS query results by domain and query type - Integrated cache storage after successful DNS resolutions - Enhanced error handling to serve cached responses as fallback when upstream DNS fails
This commit is contained in:
78
client/internal/dnsfwd/cache.go
Normal file
78
client/internal/dnsfwd/cache.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
ip4Addrs []netip.Addr
|
||||
ip6Addrs []netip.Addr
|
||||
}
|
||||
|
||||
func newCache() *cache {
|
||||
return &cache{
|
||||
records: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.records[normalizeDomain(domain)]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
return slices.Clone(entry.ip4Addrs), true
|
||||
case dns.TypeAAAA:
|
||||
return slices.Clone(entry.ip6Addrs), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
norm := normalizeDomain(domain)
|
||||
entry, exists := c.records[norm]
|
||||
if !exists {
|
||||
entry = &cacheEntry{}
|
||||
c.records[norm] = entry
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
entry.ip4Addrs = slices.Clone(addrs)
|
||||
case dns.TypeAAAA:
|
||||
entry.ip6Addrs = slices.Clone(addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// unset removes cached entries for the given domain and request type.
|
||||
func (c *cache) unset(domain string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.records, normalizeDomain(domain))
|
||||
}
|
||||
|
||||
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||
// lowercase and fully-qualified (with trailing dot).
|
||||
func normalizeDomain(domain string) string {
|
||||
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||
return dns.Fqdn(strings.ToLower(domain))
|
||||
}
|
||||
86
client/internal/dnsfwd/cache_test.go
Normal file
86
client/internal/dnsfwd/cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||
t.Helper()
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
t.Fatalf("parse addr %s: %v", s, err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestCacheNormalization(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
// Mixed case, without trailing dot
|
||||
domainInput := "ExAmPlE.CoM"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||
|
||||
// Lookup with lower, with trailing dot
|
||||
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Lookup with different casing again
|
||||
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSeparateTypes(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
domain := "test.local"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||
|
||||
c.set(domain, 1 /* A */, ipv4)
|
||||
c.set(domain, 28 /* AAAA */, ipv6)
|
||||
|
||||
got4, ok4 := c.get(domain, 1)
|
||||
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||
}
|
||||
|
||||
got6, ok6 := c.get(domain, 28)
|
||||
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||
c := newCache()
|
||||
domain := "clone.test"
|
||||
|
||||
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||
c.set(domain, 1, src)
|
||||
|
||||
// Mutate source slice; cache should be unaffected
|
||||
src[0] = mustAddr(t, "9.9.9.9")
|
||||
|
||||
got, ok := c.get(domain, 1)
|
||||
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Mutate returned slice; internal cache should remain unchanged
|
||||
got[0] = mustAddr(t, "4.4.4.4")
|
||||
got2, ok2 := c.get(domain, 1)
|
||||
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiss(t *testing.T) {
|
||||
c := newCache()
|
||||
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ type DNSForwarder struct {
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// remove cache entries for domains that no longer appear
|
||||
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||
|
||||
f.fwdEntries = entries
|
||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||
}
|
||||
|
||||
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||
// in the old list but not present in the new list.
|
||||
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||
if f.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newSet := make(map[string]struct{}, len(newEntries))
|
||||
for _, e := range newEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, e := range oldEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
pattern := e.Domain.PunycodeString()
|
||||
if _, ok := newSet[pattern]; !ok {
|
||||
f.cache.unset(pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
err error,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||
}
|
||||
|
||||
// Ensures that when the first query succeeds and populates the cache,
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First call resolves successfully and populates cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
// Second call fails upstream; forwarder should serve from cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
// First query: populate cache
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("9.8.7.6")
|
||||
|
||||
// Initial resolution with mixed case to populate cache
|
||||
mixedQuery := "ExAmPlE.CoM"
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
// Test complex overlapping pattern scenarios
|
||||
mockFirewall := &MockFirewall{}
|
||||
|
||||
Reference in New Issue
Block a user