diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 4b2e562..35f1a98 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -151,59 +151,6 @@ type natState struct { rewrittenTo netip.Addr // The address we rewrote to } -// dnsCache entry for caching resolved addresses -type dnsCacheEntry struct { - addr netip.Addr - expiresAt time.Time -} - -// dnsCache provides TTL-based caching for DNS lookups -type dnsCache struct { - mu sync.RWMutex - entries map[string]*dnsCacheEntry - ttl time.Duration -} - -// newDNSCache creates a new DNS cache with the specified TTL -func newDNSCache(ttl time.Duration) *dnsCache { - return &dnsCache{ - entries: make(map[string]*dnsCacheEntry), - ttl: ttl, - } -} - -// get retrieves a cached address if it exists and hasn't expired -func (c *dnsCache) get(domain string) (netip.Addr, bool) { - c.mu.RLock() - entry, exists := c.entries[domain] - c.mu.RUnlock() - - if !exists { - return netip.Addr{}, false - } - - if time.Now().After(entry.expiresAt) { - // Entry expired, remove it - c.mu.Lock() - delete(c.entries, domain) - c.mu.Unlock() - return netip.Addr{}, false - } - - return entry.addr, true -} - -// set stores an address in the cache with the configured TTL -func (c *dnsCache) set(domain string, addr netip.Addr) { - c.mu.Lock() - defer c.mu.Unlock() - - c.entries[domain] = &dnsCacheEntry{ - addr: addr, - expiresAt: time.Now().Add(c.ttl), - } -} - // ProxyHandler handles packet injection and extraction for promiscuous mode type ProxyHandler struct { proxyStack *stack.Stack @@ -214,7 +161,6 @@ type ProxyHandler struct { subnetLookup *SubnetLookup natTable map[connKey]*natState natMu sync.RWMutex - dnsCache *dnsCache enabled bool } @@ -235,7 +181,6 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { enabled: true, subnetLookup: NewSubnetLookup(), natTable: make(map[connKey]*natState), - dnsCache: newDNSCache(5 * time.Minute), // Cache DNS lookups for 5 minutes proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -309,9 +254,8 @@ func (p *ProxyHandler) RemoveSubnetRule(sourcePrefix, destPrefix netip.Prefix) { // resolveRewriteAddress resolves a rewrite address which can be either: // - An IP address with CIDR notation (e.g., "192.168.1.1/32") - returns the IP directly // - A plain IP address (e.g., "192.168.1.1") - returns the IP directly -// - A domain name (e.g., "example.com") - performs DNS lookup with caching +// - A domain name (e.g., "example.com") - performs DNS lookup func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, error) { - logger.Debug("Resolving rewrite address: %s", rewriteTo) // First, try to parse as a CIDR prefix (e.g., "192.168.1.1/32") @@ -324,14 +268,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro return addr, nil } - // Not an IP address, treat as domain name - // Check cache first - if cachedAddr, found := p.dnsCache.get(rewriteTo); found { - logger.Debug("DNS cache hit for %s: %s", rewriteTo, cachedAddr) - return cachedAddr, nil - } - - // Cache miss, perform DNS lookup + // Not an IP address, treat as domain name - perform DNS lookup ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -348,9 +285,7 @@ func (p *ProxyHandler) resolveRewriteAddress(rewriteTo string) (netip.Addr, erro ip := ips[0] if ip4 := ip.To4(); ip4 != nil { addr := netip.AddrFrom4([4]byte{ip4[0], ip4[1], ip4[2], ip4[3]}) - // Cache the result - p.dnsCache.set(rewriteTo, addr) - logger.Debug("DNS cache miss for %s, resolved to %s", rewriteTo, addr) + logger.Debug("Resolved %s to %s", rewriteTo, addr) return addr, nil } @@ -451,21 +386,8 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { if matchedRule != nil { // Check if we need to perform DNAT if matchedRule.RewriteTo != "" { - // Resolve the rewrite address (could be IP or domain) - newDst, err := p.resolveRewriteAddress(matchedRule.RewriteTo) - if err != nil { - // Failed to resolve, skip DNAT but still proxy the packet - pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packet), - }) - p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) - return true - } - - // Perform DNAT - rewrite destination IP - originalDst := dstAddr - - // Create connection tracking key + // Create connection tracking key using original destination + // This allows us to check if we've already resolved for this connection var srcPort uint16 switch protocol { case header.TCPProtocolNumber: @@ -476,21 +398,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { srcPort = udpHeader.SourcePort() } + // Key using original destination to track the connection key := connKey{ srcIP: srcAddr.String(), srcPort: srcPort, - dstIP: newDst.String(), + dstIP: dstAddr.String(), dstPort: dstPort, proto: uint8(protocol), } - // Store NAT state for reverse translation - p.natMu.Lock() - p.natTable[key] = &natState{ - originalDst: originalDst, - rewrittenTo: newDst, + // Check if we already have a NAT entry for this connection + p.natMu.RLock() + existingEntry, exists := p.natTable[key] + p.natMu.RUnlock() + + var newDst netip.Addr + if exists { + // Use the previously resolved address for this connection + newDst = existingEntry.rewrittenTo + logger.Debug("Using existing NAT entry for connection: %s -> %s", dstAddr, newDst) + } else { + // New connection - resolve the rewrite address + var err error + newDst, err = p.resolveRewriteAddress(matchedRule.RewriteTo) + if err != nil { + // Failed to resolve, skip DNAT but still proxy the packet + logger.Debug("Failed to resolve rewrite address: %v", err) + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + return true + } + + // Store NAT state for this connection + p.natMu.Lock() + p.natTable[key] = &natState{ + originalDst: dstAddr, + rewrittenTo: newDst, + } + p.natMu.Unlock() + logger.Debug("New NAT entry for connection: %s -> %s", dstAddr, newDst) } - p.natMu.Unlock() // Rewrite the packet packet = p.rewritePacketDestination(packet, newDst) @@ -660,20 +609,23 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { } } - // Look up NAT state (key is based on the request, so dst/src are swapped for replies) - key := connKey{ - srcIP: dstIP.String(), - srcPort: dstPort, - dstIP: srcIP.String(), - dstPort: srcPort, - proto: uint8(protocol), - } - + // Look up NAT state for reverse translation + // The key uses the original dst (before rewrite), so for replies we need to + // find the entry where the rewritten address matches the current source p.natMu.RLock() - natEntry, exists := p.natTable[key] + var natEntry *natState + for k, entry := range p.natTable { + // Match: reply's dst should be original src, reply's src should be rewritten dst + if k.srcIP == dstIP.String() && k.srcPort == dstPort && + entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort && + k.proto == uint8(protocol) { + natEntry = entry + break + } + } p.natMu.RUnlock() - if exists { + if natEntry != nil { // Perform reverse NAT - rewrite source to original destination packet = p.rewritePacketSource(packet, natEntry.originalDst) if packet != nil {