Change DNS lookup to conntrack

This commit is contained in:
Owen
2025-12-04 20:13:48 -05:00
parent d8b4fb4acb
commit 4dbf200cca

View File

@@ -151,59 +151,6 @@ type natState struct {
rewrittenTo netip.Addr // The address we rewrote to
}
// dnsCache entry for caching resolved addresses
type dnsCacheEntry struct {
addr netip.Addr
expiresAt time.Time
}
// dnsCache provides TTL-based caching for DNS lookups
type dnsCache struct {
mu sync.RWMutex
entries map[string]*dnsCacheEntry
ttl time.Duration
}
// newDNSCache creates a new DNS cache with the specified TTL
func newDNSCache(ttl time.Duration) *dnsCache {
return &dnsCache{
entries: make(map[string]*dnsCacheEntry),
ttl: ttl,
}
}
// get retrieves a cached address if it exists and hasn't expired
func (c *dnsCache) get(domain string) (netip.Addr, bool) {
c.mu.RLock()
entry, exists := c.entries[domain]
c.mu.RUnlock()
if !exists {
return netip.Addr{}, false
}
if time.Now().After(entry.expiresAt) {
// Entry expired, remove it
c.mu.Lock()
delete(c.entries, domain)
c.mu.Unlock()
return netip.Addr{}, false
}
return entry.addr, true
}
// set stores an address in the cache with the configured TTL
func (c *dnsCache) set(domain string, addr netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
c.entries[domain] = &dnsCacheEntry{
addr: addr,
expiresAt: time.Now().Add(c.ttl),
}
}
// ProxyHandler handles packet injection and extraction for promiscuous mode
type ProxyHandler struct {
proxyStack *stack.Stack
@@ -214,7 +161,6 @@ type ProxyHandler struct {
subnetLookup *SubnetLookup
natTable map[connKey]*natState
natMu sync.RWMutex
dnsCache *dnsCache
enabled bool
}
@@ -235,7 +181,6 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
enabled: true,
subnetLookup: NewSubnetLookup(),
natTable: make(map[connKey]*natState),
dnsCache: newDNSCache(5 * time.Minute), // Cache DNS lookups for 5 minutes
proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -309,9 +254,8 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) {
// resolveRewriteAddress resolves a rewrite address which can be either:
// - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly
// - A plain IP address (e.g., "192.168.1.1") - returns the IP directly
// - A domain name (e.g., "example.com") - performs DNS lookup with caching
// - A domain name (e.g., "example.com") - performs DNS lookup
func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) {
logger.Debug("Resolving rewrite address: %s", rewriteTo)
// First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32")
@@ -324,14 +268,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro
return addr, nil
}
// Not an IP address, treat as domain name
// Check cache first
if cachedAddr, found := p.dnsCache.get(rewriteTo); found {
logger.Debug("DNS cache hit for %s: %s", rewriteTo, cachedAddr)
return cachedAddr, nil
}
// Cache miss, perform DNS lookup
// Not an IP address, treat as domain name - perform DNS lookup
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -348,9 +285,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro
ip := ips[0]
if ip4 := ip.To4(); ip4 != nil {
addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]})
// Cache the result
p.dnsCache.set(rewriteTo, addr)
logger.Debug("DNS cache miss for %s, resolved to %s", rewriteTo, addr)
logger.Debug("Resolved %s to %s", rewriteTo, addr)
return addr, nil
}
@@ -451,21 +386,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
if matchedRule != nil {
// Check if we need to perform DNAT
if matchedRule.RewriteTo != "" {
// Resolve the rewrite address (could be IP or domain)
newDst, err := p.resolveRewriteAddress(matchedRule.RewriteTo)
if err != nil {
// Failed to resolve, skip DNAT but still proxy the packet
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
// Perform DNAT - rewrite destination IP
originalDst := dstAddr
// Create connection tracking key
// Create connection tracking key using original destination
// This allows us to check if we've already resolved for this connection
var srcPort uint16
switch protocol {
case header.TCPProtocolNumber:
@@ -476,21 +398,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
srcPort = udpHeader.SourcePort()
}
// Key using original destination to track the connection
key := connKey{
srcIP: srcAddr.String(),
srcPort: srcPort,
dstIP: newDst.String(),
dstIP: dstAddr.String(),
dstPort: dstPort,
proto: uint8(protocol),
}
// Store NAT state for reverse translation
p.natMu.Lock()
p.natTable[key] = &natState{
originalDst: originalDst,
rewrittenTo: newDst,
// Check if we already have a NAT entry for this connection
p.natMu.RLock()
existingEntry, exists := p.natTable[key]
p.natMu.RUnlock()
var newDst netip.Addr
if exists {
// Use the previously resolved address for this connection
newDst = existingEntry.rewrittenTo
logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst)
} else {
// New connection - resolve the rewrite address
var err error
newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo)
if err != nil {
// Failed to resolve, skip DNAT but still proxy the packet
logger.Debug("Failed to resolve rewrite address: %v", err)
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
return true
}
// Store NAT state for this connection
p.natMu.Lock()
p.natTable[key] = &natState{
originalDst: dstAddr,
rewrittenTo: newDst,
}
p.natMu.Unlock()
logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst)
}
p.natMu.Unlock()
// Rewrite the packet
packet = p.rewritePacketDestination(packet, newDst)
@@ -660,20 +609,23 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
}
}
// Look up NAT state (key is based on the request, so dst/src are swapped for replies)
key := connKey{
srcIP: dstIP.String(),
srcPort: dstPort,
dstIP: srcIP.String(),
dstPort: srcPort,
proto: uint8(protocol),
}
// Look up NAT state for reverse translation
// The key uses the original dst (before rewrite), so for replies we need to
// find the entry where the rewritten address matches the current source
p.natMu.RLock()
natEntry, exists := p.natTable[key]
var natEntry *natState
for k, entry := range p.natTable {
// Match: reply's dst should be original src, reply's src should be rewritten dst
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
k.proto == uint8(protocol) {
natEntry = entry
break
}
}
p.natMu.RUnlock()
if exists {
if natEntry != nil {
// Perform reverse NAT - rewrite source to original destination
packet = p.rewritePacketSource(packet, natEntry.originalDst)
if packet != nil {