Improve nat perf

This commit is contained in:
Viktor Liu
2025-06-17 13:55:35 +02:00
parent 8684981b57
commit 7cd44a9a3c
2 changed files with 152 additions and 43 deletions

View File

@@ -109,6 +109,7 @@ type Manager struct {
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
}
// decoder for packages

View File

@@ -15,8 +15,24 @@ func ipv4Checksum(header []byte) uint16 {
return 0
}
var sum uint32
for i := 0; i < len(header)-1; i += 2 {
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
@@ -24,30 +40,90 @@ func ipv4Checksum(header []byte) uint16 {
sum += uint32(header[len(header)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum uint32
for i := 0; i < len(data)-1; i += 2 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xFFFF) + (sum >> 16)
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (b *biDNATMap) len() int {
return len(b.forward)
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
@@ -58,11 +134,20 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
m.dnatMutex.Unlock()
return nil
}
@@ -77,6 +162,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
@@ -91,7 +177,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
}
m.dnatMutex.RLock()
translated, exists := m.dnatMappings[addr]
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
@@ -103,15 +189,9 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
}
m.dnatMutex.RLock()
defer m.dnatMutex.RUnlock()
for original, translated := range m.dnatMappings {
if translated == translatedAddr {
return original, true
}
}
return translatedAddr, false
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
@@ -120,22 +200,27 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
_, dstIP := m.extractIPs(d)
if !dstIP.IsValid() || !dstIP.Is4() {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
if m.logger != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
}
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
if m.logger != nil {
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
}
return true
}
@@ -145,28 +230,33 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
srcIP, _ := m.extractIPs(d)
if !srcIP.IsValid() || !srcIP.Is4() {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
if m.logger != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
}
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
if m.logger != nil {
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
}
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return fmt.Errorf("only IPv4 supported")
}
@@ -177,6 +267,10 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
@@ -197,7 +291,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return fmt.Errorf("only IPv4 supported")
}
@@ -208,6 +302,10 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
@@ -271,22 +369,32 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
for (sum >> 16) > 0 {
sum = (sum & 0xffff) + (sum >> 16)
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)