diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 625a8af..7b1a77d 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -7,6 +7,7 @@ import ( "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -24,10 +25,15 @@ type PortRange struct { } // SubnetRule represents a subnet with optional port restrictions and source address +// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed: +// - Incoming packets: destination IP is rewritten to RewriteTo.Addr() +// - Outgoing packets: source IP is rewritten back to the original destination +// +// This allows transparent proxying where traffic appears to come from the rewritten address type SubnetRule struct { SourcePrefix netip.Prefix // Source IP prefix (who is sending) DestPrefix netip.Prefix // Destination IP prefix (where it's going) - RewriteTo netip.Prefix // Optional rewrite address for destination + RewriteTo netip.Prefix // Optional rewrite address for DNAT (destination NAT) PortRanges []PortRange // empty slice means all ports allowed } @@ -83,13 +89,13 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { } // Match checks if a source IP, destination IP, and port match any subnet rule -// Returns true if BOTH: +// Returns the matched rule if BOTH: // - The source IP is in the rule's source prefix // - The destination IP is in the rule's destination prefix // - The port is in an allowed range (or no port restrictions exist) // -// This implementation uses O(n) iteration but checks exact prefix matches first for common cases -func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) bool { +// Returns nil if no rule matches +func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule { sl.mu.RLock() defer sl.mu.RUnlock() @@ -107,18 +113,33 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) bool { // Both IPs match - now check port restrictions // If no port ranges specified, all ports are allowed if len(rule.PortRanges) == 0 { - return true + return rule } // Check if port is in any of the allowed ranges for _, pr := range rule.PortRanges { if port >= pr.Min && port <= pr.Max { - return true + return rule } } } - return false + return nil +} + +// connKey uniquely identifies a connection for NAT tracking +type connKey struct { + srcIP string + srcPort uint16 + dstIP string + dstPort uint16 + proto uint8 +} + +// natState tracks NAT translation state for reverse translation +type natState struct { + originalDst netip.Addr // Original destination before DNAT + rewrittenTo netip.Addr // The address we rewrote to } // ProxyHandler handles packet injection and extraction for promiscuous mode @@ -129,6 +150,8 @@ type ProxyHandler struct { tcpHandler *TCPHandler udpHandler *UDPHandler subnetLookup *SubnetLookup + natTable map[connKey]*natState + natMu sync.RWMutex enabled bool } @@ -148,6 +171,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { handler := &ProxyHandler{ enabled: true, subnetLookup: NewSubnetLookup(), + natTable: make(map[connKey]*natState), proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -307,7 +331,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { } // Check if the source IP, destination IP, and port match any subnet rule - if p.subnetLookup.Match(srcAddr, dstAddr, dstPort) { + matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort) + if matchedRule != nil { + // Check if we need to perform DNAT + if matchedRule.RewriteTo.IsValid() && matchedRule.RewriteTo.Addr().IsValid() { + // Perform DNAT - rewrite destination IP + originalDst := dstAddr + newDst := matchedRule.RewriteTo.Addr() + + // Create connection tracking key + var srcPort uint16 + switch protocol { + case header.TCPProtocolNumber: + tcpHeader := header.TCP(packet[headerLen:]) + srcPort = tcpHeader.SourcePort() + case header.UDPProtocolNumber: + udpHeader := header.UDP(packet[headerLen:]) + srcPort = udpHeader.SourcePort() + } + + key := connKey{ + srcIP: srcAddr.String(), + srcPort: srcPort, + dstIP: newDst.String(), + dstPort: dstPort, + proto: uint8(protocol), + } + + // Store NAT state for reverse translation + p.natMu.Lock() + p.natTable[key] = &natState{ + originalDst: originalDst, + rewrittenTo: newDst, + } + p.natMu.Unlock() + + // Rewrite the packet + packet = p.rewritePacketDestination(packet, newDst) + if packet == nil { + return false + } + } + // Inject into proxy stack pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), @@ -319,6 +384,118 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { return false } +// rewritePacketDestination rewrites the destination IP in a packet and recalculates checksums +func (p *ProxyHandler) rewritePacketDestination(packet []byte, newDst netip.Addr) []byte { + if len(packet) < header.IPv4MinimumSize { + return nil + } + + // Make a copy to avoid modifying the original + pkt := make([]byte, len(packet)) + copy(pkt, packet) + + ipv4Header := header.IPv4(pkt) + headerLen := int(ipv4Header.HeaderLength()) + + // Rewrite destination IP + newDstBytes := newDst.As4() + newDstAddr := tcpip.AddrFrom4(newDstBytes) + ipv4Header.SetDestinationAddress(newDstAddr) + + // Recalculate IP checksum + ipv4Header.SetChecksum(0) + ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum()) + + // Update transport layer checksum if needed + protocol := ipv4Header.TransportProtocol() + switch protocol { + case header.TCPProtocolNumber: + if len(pkt) >= headerLen+header.TCPMinimumSize { + tcpHeader := header.TCP(pkt[headerLen:]) + tcpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.TCPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + tcpHeader.SetChecksum(^xsum) + } + case header.UDPProtocolNumber: + if len(pkt) >= headerLen+header.UDPMinimumSize { + udpHeader := header.UDP(pkt[headerLen:]) + udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + udpHeader.SetChecksum(^xsum) + } + } + + return pkt +} + +// rewritePacketSource rewrites the source IP in a packet and recalculates checksums (for reverse NAT) +func (p *ProxyHandler) rewritePacketSource(packet []byte, newSrc netip.Addr) []byte { + if len(packet) < header.IPv4MinimumSize { + return nil + } + + // Make a copy to avoid modifying the original + pkt := make([]byte, len(packet)) + copy(pkt, packet) + + ipv4Header := header.IPv4(pkt) + headerLen := int(ipv4Header.HeaderLength()) + + // Rewrite source IP + newSrcBytes := newSrc.As4() + newSrcAddr := tcpip.AddrFrom4(newSrcBytes) + ipv4Header.SetSourceAddress(newSrcAddr) + + // Recalculate IP checksum + ipv4Header.SetChecksum(0) + ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum()) + + // Update transport layer checksum if needed + protocol := ipv4Header.TransportProtocol() + switch protocol { + case header.TCPProtocolNumber: + if len(pkt) >= headerLen+header.TCPMinimumSize { + tcpHeader := header.TCP(pkt[headerLen:]) + tcpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.TCPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + tcpHeader.SetChecksum(^xsum) + } + case header.UDPProtocolNumber: + if len(pkt) >= headerLen+header.UDPMinimumSize { + udpHeader := header.UDP(pkt[headerLen:]) + udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + ipv4Header.SourceAddress(), + ipv4Header.DestinationAddress(), + uint16(len(pkt)-headerLen), + ) + xsum = checksum.Checksum(pkt[headerLen:], xsum) + udpHeader.SetChecksum(^xsum) + } + } + + return pkt +} + // ReadOutgoingPacket reads packets from the proxy stack that need to be // sent back through the tunnel func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { @@ -330,6 +507,55 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { if pkt != nil { view := pkt.ToView() pkt.DecRef() + + // Check if we need to perform reverse NAT + packet := view.AsSlice() + if len(packet) >= header.IPv4MinimumSize && packet[0]>>4 == 4 { + ipv4Header := header.IPv4(packet) + srcIP := ipv4Header.SourceAddress() + dstIP := ipv4Header.DestinationAddress() + protocol := ipv4Header.TransportProtocol() + headerLen := int(ipv4Header.HeaderLength()) + + // Extract ports + var srcPort, dstPort uint16 + switch protocol { + case header.TCPProtocolNumber: + if len(packet) >= headerLen+header.TCPMinimumSize { + tcpHeader := header.TCP(packet[headerLen:]) + srcPort = tcpHeader.SourcePort() + dstPort = tcpHeader.DestinationPort() + } + case header.UDPProtocolNumber: + if len(packet) >= headerLen+header.UDPMinimumSize { + udpHeader := header.UDP(packet[headerLen:]) + srcPort = udpHeader.SourcePort() + dstPort = udpHeader.DestinationPort() + } + } + + // Look up NAT state (key is based on the request, so dst/src are swapped for replies) + key := connKey{ + srcIP: dstIP.String(), + srcPort: dstPort, + dstIP: srcIP.String(), + dstPort: srcPort, + proto: uint8(protocol), + } + + p.natMu.RLock() + natEntry, exists := p.natTable[key] + p.natMu.RUnlock() + + if exists { + // Perform reverse NAT - rewrite source to original destination + packet = p.rewritePacketSource(packet, natEntry.originalDst) + if packet != nil { + return buffer.NewViewWithData(packet) + } + } + } + return view }