From 5dd5a56379f8eda74e7218d18c444d289f659293 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 3 Dec 2025 22:00:23 -0500 Subject: [PATCH] Add caching to the dns requests - is this good enough? --- clients/clients.go | 6 ++-- holepunch/holepunch.go | 2 +- netstack2/proxy.go | 76 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 77 insertions(+), 7 deletions(-) diff --git a/clients/clients.go b/clients/clients.go index 7d22e45..4ce1a83 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -694,7 +694,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) - logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } return nil @@ -1085,7 +1085,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) - logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } } @@ -1201,7 +1201,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { } s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) - logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange) + logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } } diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go index 2447ea4..379bddd 100644 --- a/holepunch/holepunch.go +++ b/holepunch/holepunch.go @@ -38,7 +38,7 @@ type Manager struct { sendHolepunchInterval time.Duration } -const sendHolepunchIntervalMax = 3 * time.Second +const sendHolepunchIntervalMax = 60 * time.Second const sendHolepunchIntervalMin = 1 * time.Second // NewManager creates a new hole punch manager diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 8e9c5e3..4b2e562 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/fosrl/newt/logger" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checksum" @@ -150,6 +151,59 @@ type natState struct { rewrittenTo netip.Addr // The address we rewrote to } +// dnsCache entry for caching resolved addresses +type dnsCacheEntry struct { + addr netip.Addr + expiresAt time.Time +} + +// dnsCache provides TTL-based caching for DNS lookups +type dnsCache struct { + mu sync.RWMutex + entries map[string]*dnsCacheEntry + ttl time.Duration +} + +// newDNSCache creates a new DNS cache with the specified TTL +func newDNSCache(ttl time.Duration) *dnsCache { + return &dnsCache{ + entries: make(map[string]*dnsCacheEntry), + ttl: ttl, + } +} + +// get retrieves a cached address if it exists and hasn't expired +func (c *dnsCache) get(domain string) (netip.Addr, bool) { + c.mu.RLock() + entry, exists := c.entries[domain] + c.mu.RUnlock() + + if !exists { + return netip.Addr{}, false + } + + if time.Now().After(entry.expiresAt) { + // Entry expired, remove it + c.mu.Lock() + delete(c.entries, domain) + c.mu.Unlock() + return netip.Addr{}, false + } + + return entry.addr, true +} + +// set stores an address in the cache with the configured TTL +func (c *dnsCache) set(domain string, addr netip.Addr) { + c.mu.Lock() + defer c.mu.Unlock() + + c.entries[domain] = &dnsCacheEntry{ + addr: addr, + expiresAt: time.Now().Add(c.ttl), + } +} + // ProxyHandler handles packet injection and extraction for promiscuous mode type ProxyHandler struct { proxyStack *stack.Stack @@ -160,6 +214,7 @@ type ProxyHandler struct { subnetLookup *SubnetLookup natTable map[connKey]*natState natMu sync.RWMutex + dnsCache *dnsCache enabled bool } @@ -180,6 +235,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { enabled: true, subnetLookup: NewSubnetLookup(), natTable: make(map[connKey]*natState), + dnsCache: newDNSCache(5 * time.Minute), // Cache DNS lookups for 5 minutes proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -253,8 +309,11 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) { // resolveRewriteAddress resolves a rewrite address which can be either: // - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly // - A plain IP address (e.g., "192.168.1.1") - returns the IP directly -// - A domain name (e.g., "example.com") - performs DNS lookup at request time +// - A domain name (e.g., "example.com") - performs DNS lookup with caching func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) { + + logger.Debug("Resolving rewrite address: %s", rewriteTo) + // First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32") if prefix, err := netip.ParsePrefix(rewriteTo); err == nil { return prefix.Addr(), nil @@ -265,7 +324,14 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro return addr, nil } - // Not an IP address, treat as domain name and perform DNS lookup + // Not an IP address, treat as domain name + // Check cache first + if cachedAddr, found := p.dnsCache.get(rewriteTo); found { + logger.Debug("DNS cache hit for %s: %s", rewriteTo, cachedAddr) + return cachedAddr, nil + } + + // Cache miss, perform DNS lookup ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -281,7 +347,11 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro // Use the first resolved IP address ip := ips[0] if ip4 := ip.To4(); ip4 != nil { - return netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}), nil + addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}) + // Cache the result + p.dnsCache.set(rewriteTo, addr) + logger.Debug("DNS cache miss for %s, resolved to %s", rewriteTo, addr) + return addr, nil } return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)