mirror of
https://github.com/fosrl/newt.git
synced 2026-03-01 16:26:40 +00:00
Rewriting desitnation works
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/checksum"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
@@ -24,10 +25,15 @@ type PortRange struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SubnetRule represents a subnet with optional port restrictions and source address
|
// SubnetRule represents a subnet with optional port restrictions and source address
|
||||||
|
// When RewriteTo is set, DNAT (Destination Network Address Translation) is performed:
|
||||||
|
// - Incoming packets: destination IP is rewritten to RewriteTo.Addr()
|
||||||
|
// - Outgoing packets: source IP is rewritten back to the original destination
|
||||||
|
//
|
||||||
|
// This allows transparent proxying where traffic appears to come from the rewritten address
|
||||||
type SubnetRule struct {
|
type SubnetRule struct {
|
||||||
SourcePrefix netip.Prefix // Source IP prefix (who is sending)
|
SourcePrefix netip.Prefix // Source IP prefix (who is sending)
|
||||||
DestPrefix netip.Prefix // Destination IP prefix (where it's going)
|
DestPrefix netip.Prefix // Destination IP prefix (where it's going)
|
||||||
RewriteTo netip.Prefix // Optional rewrite address for destination
|
RewriteTo netip.Prefix // Optional rewrite address for DNAT (destination NAT)
|
||||||
PortRanges []PortRange // empty slice means all ports allowed
|
PortRanges []PortRange // empty slice means all ports allowed
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,13 +89,13 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Match checks if a source IP, destination IP, and port match any subnet rule
|
// Match checks if a source IP, destination IP, and port match any subnet rule
|
||||||
// Returns true if BOTH:
|
// Returns the matched rule if BOTH:
|
||||||
// - The source IP is in the rule's source prefix
|
// - The source IP is in the rule's source prefix
|
||||||
// - The destination IP is in the rule's destination prefix
|
// - The destination IP is in the rule's destination prefix
|
||||||
// - The port is in an allowed range (or no port restrictions exist)
|
// - The port is in an allowed range (or no port restrictions exist)
|
||||||
//
|
//
|
||||||
// This implementation uses O(n) iteration but checks exact prefix matches first for common cases
|
// Returns nil if no rule matches
|
||||||
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) bool {
|
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule {
|
||||||
sl.mu.RLock()
|
sl.mu.RLock()
|
||||||
defer sl.mu.RUnlock()
|
defer sl.mu.RUnlock()
|
||||||
|
|
||||||
@@ -107,18 +113,33 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) bool {
|
|||||||
// Both IPs match - now check port restrictions
|
// Both IPs match - now check port restrictions
|
||||||
// If no port ranges specified, all ports are allowed
|
// If no port ranges specified, all ports are allowed
|
||||||
if len(rule.PortRanges) == 0 {
|
if len(rule.PortRanges) == 0 {
|
||||||
return true
|
return rule
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if port is in any of the allowed ranges
|
// Check if port is in any of the allowed ranges
|
||||||
for _, pr := range rule.PortRanges {
|
for _, pr := range rule.PortRanges {
|
||||||
if port >= pr.Min && port <= pr.Max {
|
if port >= pr.Min && port <= pr.Max {
|
||||||
return true
|
return rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// connKey uniquely identifies a connection for NAT tracking
|
||||||
|
type connKey struct {
|
||||||
|
srcIP string
|
||||||
|
srcPort uint16
|
||||||
|
dstIP string
|
||||||
|
dstPort uint16
|
||||||
|
proto uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// natState tracks NAT translation state for reverse translation
|
||||||
|
type natState struct {
|
||||||
|
originalDst netip.Addr // Original destination before DNAT
|
||||||
|
rewrittenTo netip.Addr // The address we rewrote to
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
// ProxyHandler handles packet injection and extraction for promiscuous mode
|
||||||
@@ -129,6 +150,8 @@ type ProxyHandler struct {
|
|||||||
tcpHandler *TCPHandler
|
tcpHandler *TCPHandler
|
||||||
udpHandler *UDPHandler
|
udpHandler *UDPHandler
|
||||||
subnetLookup *SubnetLookup
|
subnetLookup *SubnetLookup
|
||||||
|
natTable map[connKey]*natState
|
||||||
|
natMu sync.RWMutex
|
||||||
enabled bool
|
enabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,6 +171,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
|||||||
handler := &ProxyHandler{
|
handler := &ProxyHandler{
|
||||||
enabled: true,
|
enabled: true,
|
||||||
subnetLookup: NewSubnetLookup(),
|
subnetLookup: NewSubnetLookup(),
|
||||||
|
natTable: make(map[connKey]*natState),
|
||||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||||
proxyStack: stack.New(stack.Options{
|
proxyStack: stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
@@ -307,7 +331,48 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the source IP, destination IP, and port match any subnet rule
|
// Check if the source IP, destination IP, and port match any subnet rule
|
||||||
if p.subnetLookup.Match(srcAddr, dstAddr, dstPort) {
|
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort)
|
||||||
|
if matchedRule != nil {
|
||||||
|
// Check if we need to perform DNAT
|
||||||
|
if matchedRule.RewriteTo.IsValid() && matchedRule.RewriteTo.Addr().IsValid() {
|
||||||
|
// Perform DNAT - rewrite destination IP
|
||||||
|
originalDst := dstAddr
|
||||||
|
newDst := matchedRule.RewriteTo.Addr()
|
||||||
|
|
||||||
|
// Create connection tracking key
|
||||||
|
var srcPort uint16
|
||||||
|
switch protocol {
|
||||||
|
case header.TCPProtocolNumber:
|
||||||
|
tcpHeader := header.TCP(packet[headerLen:])
|
||||||
|
srcPort = tcpHeader.SourcePort()
|
||||||
|
case header.UDPProtocolNumber:
|
||||||
|
udpHeader := header.UDP(packet[headerLen:])
|
||||||
|
srcPort = udpHeader.SourcePort()
|
||||||
|
}
|
||||||
|
|
||||||
|
key := connKey{
|
||||||
|
srcIP: srcAddr.String(),
|
||||||
|
srcPort: srcPort,
|
||||||
|
dstIP: newDst.String(),
|
||||||
|
dstPort: dstPort,
|
||||||
|
proto: uint8(protocol),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store NAT state for reverse translation
|
||||||
|
p.natMu.Lock()
|
||||||
|
p.natTable[key] = &natState{
|
||||||
|
originalDst: originalDst,
|
||||||
|
rewrittenTo: newDst,
|
||||||
|
}
|
||||||
|
p.natMu.Unlock()
|
||||||
|
|
||||||
|
// Rewrite the packet
|
||||||
|
packet = p.rewritePacketDestination(packet, newDst)
|
||||||
|
if packet == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Inject into proxy stack
|
// Inject into proxy stack
|
||||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
Payload: buffer.MakeWithData(packet),
|
Payload: buffer.MakeWithData(packet),
|
||||||
@@ -319,6 +384,118 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rewritePacketDestination rewrites the destination IP in a packet and recalculates checksums
|
||||||
|
func (p *ProxyHandler) rewritePacketDestination(packet []byte, newDst netip.Addr) []byte {
|
||||||
|
if len(packet) < header.IPv4MinimumSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a copy to avoid modifying the original
|
||||||
|
pkt := make([]byte, len(packet))
|
||||||
|
copy(pkt, packet)
|
||||||
|
|
||||||
|
ipv4Header := header.IPv4(pkt)
|
||||||
|
headerLen := int(ipv4Header.HeaderLength())
|
||||||
|
|
||||||
|
// Rewrite destination IP
|
||||||
|
newDstBytes := newDst.As4()
|
||||||
|
newDstAddr := tcpip.AddrFrom4(newDstBytes)
|
||||||
|
ipv4Header.SetDestinationAddress(newDstAddr)
|
||||||
|
|
||||||
|
// Recalculate IP checksum
|
||||||
|
ipv4Header.SetChecksum(0)
|
||||||
|
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
||||||
|
|
||||||
|
// Update transport layer checksum if needed
|
||||||
|
protocol := ipv4Header.TransportProtocol()
|
||||||
|
switch protocol {
|
||||||
|
case header.TCPProtocolNumber:
|
||||||
|
if len(pkt) >= headerLen+header.TCPMinimumSize {
|
||||||
|
tcpHeader := header.TCP(pkt[headerLen:])
|
||||||
|
tcpHeader.SetChecksum(0)
|
||||||
|
xsum := header.PseudoHeaderChecksum(
|
||||||
|
header.TCPProtocolNumber,
|
||||||
|
ipv4Header.SourceAddress(),
|
||||||
|
ipv4Header.DestinationAddress(),
|
||||||
|
uint16(len(pkt)-headerLen),
|
||||||
|
)
|
||||||
|
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||||
|
tcpHeader.SetChecksum(^xsum)
|
||||||
|
}
|
||||||
|
case header.UDPProtocolNumber:
|
||||||
|
if len(pkt) >= headerLen+header.UDPMinimumSize {
|
||||||
|
udpHeader := header.UDP(pkt[headerLen:])
|
||||||
|
udpHeader.SetChecksum(0)
|
||||||
|
xsum := header.PseudoHeaderChecksum(
|
||||||
|
header.UDPProtocolNumber,
|
||||||
|
ipv4Header.SourceAddress(),
|
||||||
|
ipv4Header.DestinationAddress(),
|
||||||
|
uint16(len(pkt)-headerLen),
|
||||||
|
)
|
||||||
|
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||||
|
udpHeader.SetChecksum(^xsum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketSource rewrites the source IP in a packet and recalculates checksums (for reverse NAT)
|
||||||
|
func (p *ProxyHandler) rewritePacketSource(packet []byte, newSrc netip.Addr) []byte {
|
||||||
|
if len(packet) < header.IPv4MinimumSize {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a copy to avoid modifying the original
|
||||||
|
pkt := make([]byte, len(packet))
|
||||||
|
copy(pkt, packet)
|
||||||
|
|
||||||
|
ipv4Header := header.IPv4(pkt)
|
||||||
|
headerLen := int(ipv4Header.HeaderLength())
|
||||||
|
|
||||||
|
// Rewrite source IP
|
||||||
|
newSrcBytes := newSrc.As4()
|
||||||
|
newSrcAddr := tcpip.AddrFrom4(newSrcBytes)
|
||||||
|
ipv4Header.SetSourceAddress(newSrcAddr)
|
||||||
|
|
||||||
|
// Recalculate IP checksum
|
||||||
|
ipv4Header.SetChecksum(0)
|
||||||
|
ipv4Header.SetChecksum(^ipv4Header.CalculateChecksum())
|
||||||
|
|
||||||
|
// Update transport layer checksum if needed
|
||||||
|
protocol := ipv4Header.TransportProtocol()
|
||||||
|
switch protocol {
|
||||||
|
case header.TCPProtocolNumber:
|
||||||
|
if len(pkt) >= headerLen+header.TCPMinimumSize {
|
||||||
|
tcpHeader := header.TCP(pkt[headerLen:])
|
||||||
|
tcpHeader.SetChecksum(0)
|
||||||
|
xsum := header.PseudoHeaderChecksum(
|
||||||
|
header.TCPProtocolNumber,
|
||||||
|
ipv4Header.SourceAddress(),
|
||||||
|
ipv4Header.DestinationAddress(),
|
||||||
|
uint16(len(pkt)-headerLen),
|
||||||
|
)
|
||||||
|
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||||
|
tcpHeader.SetChecksum(^xsum)
|
||||||
|
}
|
||||||
|
case header.UDPProtocolNumber:
|
||||||
|
if len(pkt) >= headerLen+header.UDPMinimumSize {
|
||||||
|
udpHeader := header.UDP(pkt[headerLen:])
|
||||||
|
udpHeader.SetChecksum(0)
|
||||||
|
xsum := header.PseudoHeaderChecksum(
|
||||||
|
header.UDPProtocolNumber,
|
||||||
|
ipv4Header.SourceAddress(),
|
||||||
|
ipv4Header.DestinationAddress(),
|
||||||
|
uint16(len(pkt)-headerLen),
|
||||||
|
)
|
||||||
|
xsum = checksum.Checksum(pkt[headerLen:], xsum)
|
||||||
|
udpHeader.SetChecksum(^xsum)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pkt
|
||||||
|
}
|
||||||
|
|
||||||
// ReadOutgoingPacket reads packets from the proxy stack that need to be
|
// ReadOutgoingPacket reads packets from the proxy stack that need to be
|
||||||
// sent back through the tunnel
|
// sent back through the tunnel
|
||||||
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
||||||
@@ -330,6 +507,55 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
|||||||
if pkt != nil {
|
if pkt != nil {
|
||||||
view := pkt.ToView()
|
view := pkt.ToView()
|
||||||
pkt.DecRef()
|
pkt.DecRef()
|
||||||
|
|
||||||
|
// Check if we need to perform reverse NAT
|
||||||
|
packet := view.AsSlice()
|
||||||
|
if len(packet) >= header.IPv4MinimumSize && packet[0]>>4 == 4 {
|
||||||
|
ipv4Header := header.IPv4(packet)
|
||||||
|
srcIP := ipv4Header.SourceAddress()
|
||||||
|
dstIP := ipv4Header.DestinationAddress()
|
||||||
|
protocol := ipv4Header.TransportProtocol()
|
||||||
|
headerLen := int(ipv4Header.HeaderLength())
|
||||||
|
|
||||||
|
// Extract ports
|
||||||
|
var srcPort, dstPort uint16
|
||||||
|
switch protocol {
|
||||||
|
case header.TCPProtocolNumber:
|
||||||
|
if len(packet) >= headerLen+header.TCPMinimumSize {
|
||||||
|
tcpHeader := header.TCP(packet[headerLen:])
|
||||||
|
srcPort = tcpHeader.SourcePort()
|
||||||
|
dstPort = tcpHeader.DestinationPort()
|
||||||
|
}
|
||||||
|
case header.UDPProtocolNumber:
|
||||||
|
if len(packet) >= headerLen+header.UDPMinimumSize {
|
||||||
|
udpHeader := header.UDP(packet[headerLen:])
|
||||||
|
srcPort = udpHeader.SourcePort()
|
||||||
|
dstPort = udpHeader.DestinationPort()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up NAT state (key is based on the request, so dst/src are swapped for replies)
|
||||||
|
key := connKey{
|
||||||
|
srcIP: dstIP.String(),
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstIP: srcIP.String(),
|
||||||
|
dstPort: srcPort,
|
||||||
|
proto: uint8(protocol),
|
||||||
|
}
|
||||||
|
|
||||||
|
p.natMu.RLock()
|
||||||
|
natEntry, exists := p.natTable[key]
|
||||||
|
p.natMu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
// Perform reverse NAT - rewrite source to original destination
|
||||||
|
packet = p.rewritePacketSource(packet, natEntry.originalDst)
|
||||||
|
if packet != nil {
|
||||||
|
return buffer.NewViewWithData(packet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return view
|
return view
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user