diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 46dfaa1..388a3d1 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -57,6 +57,17 @@ type connKey struct { proto uint8 } +// reverseConnKey uniquely identifies a connection for reverse NAT lookup (reply direction) +// Key structure: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto) +// This allows O(1) lookup of NAT entries for reply packets +type reverseConnKey struct { + rewrittenTo string // The address we rewrote to (becomes src in replies) + originalSrcIP string // Original source IP (becomes dst in replies) + originalSrcPort uint16 // Original source port (becomes dst port in replies) + originalDstPort uint16 // Original destination port (becomes src port in replies) + proto uint8 +} + // destKey identifies a destination for handler lookups (without source port since it may change) type destKey struct { srcIP string @@ -81,7 +92,8 @@ type ProxyHandler struct { icmpHandler *ICMPHandler subnetLookup *SubnetLookup natTable map[connKey]*natState - destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups + reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT + destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups natMu sync.RWMutex enabled bool icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel @@ -106,6 +118,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { enabled: true, subnetLookup: NewSubnetLookup(), natTable: make(map[connKey]*natState), + reverseNatTable: make(map[reverseConnKey]*natState), destRewriteTable: make(map[destKey]netip.Addr), icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets proxyEp: channel.New(1024, uint32(options.MTU), ""), @@ -408,10 +421,23 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { // Store NAT state for this connection p.natMu.Lock() - p.natTable[key] = &natState{ + natEntry := &natState{ originalDst: dstAddr, rewrittenTo: newDst, } + p.natTable[key] = natEntry + + // Create reverse lookup key for O(1) reply packet lookups + // Key: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto) + reverseKey := reverseConnKey{ + rewrittenTo: newDst.String(), + originalSrcIP: srcAddr.String(), + originalSrcPort: srcPort, + originalDstPort: dstPort, + proto: uint8(protocol), + } + p.reverseNatTable[reverseKey] = natEntry + // Store destination rewrite for handler lookups p.destRewriteTable[dKey] = newDst p.natMu.Unlock() @@ -610,20 +636,22 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { return view } - // Look up NAT state for reverse translation - // The key uses the original dst (before rewrite), so for replies we need to - // find the entry where the rewritten address matches the current source + // Look up NAT state for reverse translation using O(1) reverse lookup map + // Key: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto) + // For reply packets: + // - reply's srcIP = rewrittenTo (the address we rewrote to) + // - reply's dstIP = originalSrcIP (original source IP) + // - reply's srcPort = originalDstPort (original destination port) + // - reply's dstPort = originalSrcPort (original source port) p.natMu.RLock() - var natEntry *natState - for k, entry := range p.natTable { - // Match: reply's dst should be original src, reply's src should be rewritten dst - if k.srcIP == dstIP.String() && k.srcPort == dstPort && - entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort && - k.proto == uint8(protocol) { - natEntry = entry - break - } + reverseKey := reverseConnKey{ + rewrittenTo: srcIP.String(), // Reply's source is the rewritten address + originalSrcIP: dstIP.String(), // Reply's destination is the original source + originalSrcPort: dstPort, // Reply's destination port is the original source port + originalDstPort: srcPort, // Reply's source port is the original destination port + proto: uint8(protocol), } + natEntry := p.reverseNatTable[reverseKey] p.natMu.RUnlock() if natEntry != nil {