mirror of
https://github.com/fosrl/newt.git
synced 2026-03-10 04:36:40 +00:00
Add caching to the dns requests - is this good enough?
This commit is contained in:
@@ -694,7 +694,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
|||||||
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
|
||||||
|
|
||||||
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1085,7 +1085,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
|
||||||
|
|
||||||
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1201,7 +1201,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
|
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges)
|
||||||
logger.Info("Added target subnet from %s to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.PortRange)
|
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ type Manager struct {
|
|||||||
sendHolepunchInterval time.Duration
|
sendHolepunchInterval time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
const sendHolepunchIntervalMax = 3 * time.Second
|
const sendHolepunchIntervalMax = 60 * time.Second
|
||||||
const sendHolepunchIntervalMin = 1 * time.Second
|
const sendHolepunchIntervalMin = 1 * time.Second
|
||||||
|
|
||||||
// NewManager creates a new hole punch manager
|
// NewManager creates a new hole punch manager
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
@@ -150,6 +151,59 @@ type natState struct {
|
|||||||
rewrittenTo netip.Addr // The address we rewrote to
|
rewrittenTo netip.Addr // The address we rewrote to
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dnsCache entry for caching resolved addresses
|
||||||
|
type dnsCacheEntry struct {
|
||||||
|
addr netip.Addr
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// dnsCache provides TTL-based caching for DNS lookups
|
||||||
|
type dnsCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
entries map[string]*dnsCacheEntry
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// newDNSCache creates a new DNS cache with the specified TTL
|
||||||
|
func newDNSCache(ttl time.Duration) *dnsCache {
|
||||||
|
return &dnsCache{
|
||||||
|
entries: make(map[string]*dnsCacheEntry),
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get retrieves a cached address if it exists and hasn't expired
|
||||||
|
func (c *dnsCache) get(domain string) (netip.Addr, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
entry, exists := c.entries[domain]
|
||||||
|
c.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(entry.expiresAt) {
|
||||||
|
// Entry expired, remove it
|
||||||
|
c.mu.Lock()
|
||||||
|
delete(c.entries, domain)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return netip.Addr{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.addr, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// set stores an address in the cache with the configured TTL
|
||||||
|
func (c *dnsCache) set(domain string, addr netip.Addr) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.entries[domain] = &dnsCacheEntry{
|
||||||
|
addr: addr,
|
||||||
|
expiresAt: time.Now().Add(c.ttl),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
||||||
type ProxyHandler struct {
|
type ProxyHandler struct {
|
||||||
proxyStack *stack.Stack
|
proxyStack *stack.Stack
|
||||||
@@ -160,6 +214,7 @@ type ProxyHandler struct {
|
|||||||
subnetLookup *SubnetLookup
|
subnetLookup *SubnetLookup
|
||||||
natTable map[connKey]*natState
|
natTable map[connKey]*natState
|
||||||
natMu sync.RWMutex
|
natMu sync.RWMutex
|
||||||
|
dnsCache *dnsCache
|
||||||
enabled bool
|
enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +235,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
|||||||
enabled: true,
|
enabled: true,
|
||||||
subnetLookup: NewSubnetLookup(),
|
subnetLookup: NewSubnetLookup(),
|
||||||
natTable: make(map[connKey]*natState),
|
natTable: make(map[connKey]*natState),
|
||||||
|
dnsCache: newDNSCache(5 * time.Minute), // Cache DNS lookups for 5 minutes
|
||||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||||
proxyStack: stack.New(stack.Options{
|
proxyStack: stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
@@ -253,8 +309,11 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
|
|||||||
// resolveRewriteAddress resolves a rewrite address which can be either:
|
// resolveRewriteAddress resolves a rewrite address which can be either:
|
||||||
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
|
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
|
||||||
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
|
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
|
||||||
// - A domain name (e.g., "example.com") - performs DNS lookup at request time
|
// - A domain name (e.g., "example.com") - performs DNS lookup with caching
|
||||||
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
|
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
|
||||||
|
|
||||||
|
logger.Debug("Resolving rewrite address: %s", rewriteTo)
|
||||||
|
|
||||||
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
|
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
|
||||||
if prefix, err := netip.ParsePrefix(rewriteTo); err == nil {
|
if prefix, err := netip.ParsePrefix(rewriteTo); err == nil {
|
||||||
return prefix.Addr(), nil
|
return prefix.Addr(), nil
|
||||||
@@ -265,7 +324,14 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro
|
|||||||
return addr, nil
|
return addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not an IP address, treat as domain name and perform DNS lookup
|
// Not an IP address, treat as domain name
|
||||||
|
// Check cache first
|
||||||
|
if cachedAddr, found := p.dnsCache.get(rewriteTo); found {
|
||||||
|
logger.Debug("DNS cache hit for %s: %s", rewriteTo, cachedAddr)
|
||||||
|
return cachedAddr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache miss, perform DNS lookup
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -281,7 +347,11 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro
|
|||||||
// Use the first resolved IP address
|
// Use the first resolved IP address
|
||||||
ip := ips[0]
|
ip := ips[0]
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
return netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}), nil
|
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
|
||||||
|
// Cache the result
|
||||||
|
p.dnsCache.set(rewriteTo, addr)
|
||||||
|
logger.Debug("DNS cache miss for %s, resolved to %s", rewriteTo, addr)
|
||||||
|
return addr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)
|
return netip.Addr{}, fmt.Errorf("no IPv4 address found for domain %s", rewriteTo)
|
||||||
|
|||||||
Reference in New Issue
Block a user