diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 136d3741b..33aba0bf5 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -109,6 +109,7 @@ type Manager struct { dnatEnabled atomic.Bool dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex + dnatBiMap *biDNATMap } // decoder for packages diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 8c9343995..3d5fd603d 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -15,8 +15,24 @@ func ipv4Checksum(header []byte) uint16 { return 0 } - var sum uint32 - for i := 0; i < len(header)-1; i += 2 { + var sum1, sum2 uint32 + + // Parallel processing - unroll and compute two sums simultaneously + sum1 += uint32(binary.BigEndian.Uint16(header[0:2])) + sum2 += uint32(binary.BigEndian.Uint16(header[2:4])) + sum1 += uint32(binary.BigEndian.Uint16(header[4:6])) + sum2 += uint32(binary.BigEndian.Uint16(header[6:8])) + sum1 += uint32(binary.BigEndian.Uint16(header[8:10])) + // Skip checksum field at [10:12] + sum2 += uint32(binary.BigEndian.Uint16(header[12:14])) + sum1 += uint32(binary.BigEndian.Uint16(header[14:16])) + sum2 += uint32(binary.BigEndian.Uint16(header[16:18])) + sum1 += uint32(binary.BigEndian.Uint16(header[18:20])) + + sum := sum1 + sum2 + + // Handle remaining bytes for headers > 20 bytes + for i := 20; i < len(header)-1; i += 2 { sum += uint32(binary.BigEndian.Uint16(header[i : i+2])) } @@ -24,30 +40,90 @@ func ipv4Checksum(header []byte) uint16 { sum += uint32(header[len(header)-1]) << 8 } - for (sum >> 16) > 0 { - sum = (sum & 0xFFFF) + (sum >> 16) + // Optimized carry fold - single iteration handles most cases + sum = (sum & 0xFFFF) + (sum >> 16) + if sum > 0xFFFF { + sum++ } return ^uint16(sum) } func icmpChecksum(data []byte) uint16 { - var sum uint32 - for i := 0; i < len(data)-1; i += 2 { + var sum1, sum2, sum3, sum4 uint32 + i := 0 + + // Process 16 bytes at once with 4 parallel accumulators + for i <= len(data)-16 { + sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2])) + sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4])) + sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6])) + sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8])) + sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10])) + sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12])) + sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14])) + sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16])) + i += 16 + } + + sum := sum1 + sum2 + sum3 + sum4 + + // Handle remaining bytes + for i < len(data)-1 { sum += uint32(binary.BigEndian.Uint16(data[i : i+2])) + i += 2 } if len(data)%2 == 1 { sum += uint32(data[len(data)-1]) << 8 } - for (sum >> 16) > 0 { - sum = (sum & 0xFFFF) + (sum >> 16) + sum = (sum & 0xFFFF) + (sum >> 16) + if sum > 0xFFFF { + sum++ } return ^uint16(sum) } +type biDNATMap struct { + forward map[netip.Addr]netip.Addr + reverse map[netip.Addr]netip.Addr +} + +func newBiDNATMap() *biDNATMap { + return &biDNATMap{ + forward: make(map[netip.Addr]netip.Addr), + reverse: make(map[netip.Addr]netip.Addr), + } +} + +func (b *biDNATMap) set(original, translated netip.Addr) { + b.forward[original] = translated + b.reverse[translated] = original +} + +func (b *biDNATMap) delete(original netip.Addr) { + if translated, exists := b.forward[original]; exists { + delete(b.forward, original) + delete(b.reverse, translated) + } +} + +func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { + translated, exists := b.forward[original] + return translated, exists +} + +func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { + original, exists := b.reverse[translated] + return original, exists +} + +func (b *biDNATMap) len() int { + return len(b.forward) +} + func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { if !originalAddr.IsValid() || !translatedAddr.IsValid() { return fmt.Errorf("invalid IP addresses") @@ -58,11 +134,20 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr } m.dnatMutex.Lock() + defer m.dnatMutex.Unlock() + + // Initialize both maps together if either is nil + if m.dnatMappings == nil || m.dnatBiMap == nil { + m.dnatMappings = make(map[netip.Addr]netip.Addr) + m.dnatBiMap = newBiDNATMap() + } + m.dnatMappings[originalAddr] = translatedAddr + m.dnatBiMap.set(originalAddr, translatedAddr) + if len(m.dnatMappings) == 1 { m.dnatEnabled.Store(true) } - m.dnatMutex.Unlock() return nil } @@ -77,6 +162,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { } delete(m.dnatMappings, originalAddr) + m.dnatBiMap.delete(originalAddr) if len(m.dnatMappings) == 0 { m.dnatEnabled.Store(false) } @@ -91,7 +177,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { } m.dnatMutex.RLock() - translated, exists := m.dnatMappings[addr] + translated, exists := m.dnatBiMap.getTranslated(addr) m.dnatMutex.RUnlock() return translated, exists } @@ -103,15 +189,9 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, } m.dnatMutex.RLock() - defer m.dnatMutex.RUnlock() - - for original, translated := range m.dnatMappings { - if translated == translatedAddr { - return original, true - } - } - - return translatedAddr, false + original, exists := m.dnatBiMap.getOriginal(translatedAddr) + m.dnatMutex.RUnlock() + return original, exists } // translateOutboundDNAT applies DNAT translation to outbound packets @@ -120,22 +200,27 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - _, dstIP := m.extractIPs(d) - if !dstIP.IsValid() || !dstIP.Is4() { + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { return false } + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + translatedIP, exists := m.getDNATTranslation(dstIP) if !exists { return false } if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error("Failed to rewrite packet destination: %v", err) + if m.logger != nil { + m.logger.Error("Failed to rewrite packet destination: %v", err) + } return false } - m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP) + if m.logger != nil { + m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP) + } return true } @@ -145,28 +230,33 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - srcIP, _ := m.extractIPs(d) - if !srcIP.IsValid() || !srcIP.Is4() { + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { return false } + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + originalIP, exists := m.findReverseDNATMapping(srcIP) if !exists { return false } if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error("Failed to rewrite packet source: %v", err) + if m.logger != nil { + m.logger.Error("Failed to rewrite packet source: %v", err) + } return false } - m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP) + if m.logger != nil { + m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP) + } return true } // rewritePacketDestination replaces destination IP in the packet func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { return fmt.Errorf("only IPv4 supported") } @@ -177,6 +267,10 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP copy(packetData[16:20], newDst[:]) ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return fmt.Errorf("invalid IP header length") + } + binary.BigEndian.PutUint16(packetData[10:12], 0) ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) @@ -197,7 +291,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP // rewritePacketSource replaces the source IP address in the packet func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { return fmt.Errorf("only IPv4 supported") } @@ -208,6 +302,10 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip copy(packetData[12:16], newSrc[:]) ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return fmt.Errorf("invalid IP header length") + } + binary.BigEndian.PutUint16(packetData[10:12], 0) ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) @@ -271,22 +369,32 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) - for i := 0; i < len(oldBytes)-1; i += 2 { - sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2])) - } - if len(oldBytes)%2 == 1 { - sum += uint32(^oldBytes[len(oldBytes)-1]) << 8 + // Fast path for IPv4 addresses (4 bytes) - most common case + if len(oldBytes) == 4 && len(newBytes) == 4 { + sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2])) + sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) + sum += uint32(binary.BigEndian.Uint16(newBytes[0:2])) + sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) + } else { + // Fallback for other lengths + for i := 0; i < len(oldBytes)-1; i += 2 { + sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2])) + } + if len(oldBytes)%2 == 1 { + sum += uint32(^oldBytes[len(oldBytes)-1]) << 8 + } + + for i := 0; i < len(newBytes)-1; i += 2 { + sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2])) + } + if len(newBytes)%2 == 1 { + sum += uint32(newBytes[len(newBytes)-1]) << 8 + } } - for i := 0; i < len(newBytes)-1; i += 2 { - sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2])) - } - if len(newBytes)%2 == 1 { - sum += uint32(newBytes[len(newBytes)-1]) << 8 - } - - for (sum >> 16) > 0 { - sum = (sum & 0xffff) + (sum >> 16) + sum = (sum & 0xFFFF) + (sum >> 16) + if sum > 0xFFFF { + sum++ } return ^uint16(sum)