This commit is contained in:
Viktor Liu
2025-07-02 20:10:59 +02:00
parent 279b77dee0
commit 4bbca28eb6
19 changed files with 71 additions and 417 deletions

View File

@@ -1267,15 +1267,6 @@ func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint
m.logger.Debug("Unregistered netstack service on protocol %s port %d", protocol, port)
}
// isNetstackService checks if a service is registered as listening on netstack for the given protocol and port
func (m *Manager) isNetstackService(layerType gopacket.LayerType, port uint16) bool {
m.netstackServiceMutex.RLock()
defer m.netstackServiceMutex.RUnlock()
key := serviceKey{protocol: layerType, port: port}
_, exists := m.netstackServices[key]
return exists
}
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
switch protocol {

View File

@@ -16,8 +16,11 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
invalidIPHeaderLengthMsg = "invalid IP header length"
errRewriteTCPDestinationPort = "rewrite TCP destination port: %v"
)
@@ -175,21 +178,6 @@ func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstP
return conn, exists
}
// removeConnection removes a tracked connection from the NAT tracking table.
func (t *portNATTracker) removeConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
delete(t.connections, key)
}
// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts.
func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool {
t.mutex.RLock()
@@ -390,7 +378,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf(invalidIPHeaderLengthMsg)
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -425,7 +413,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf(invalidIPHeaderLengthMsg)
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -560,11 +548,12 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
if protocol == firewall.ProtocolTCP {
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
} else if protocol == firewall.ProtocolUDP {
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
} else {
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
@@ -594,11 +583,12 @@ func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.L
// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
if protocol == firewall.ProtocolTCP {
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
} else if protocol == firewall.ProtocolUDP {
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
} else {
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
@@ -747,7 +737,7 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf(invalidIPHeaderLengthMsg)
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
@@ -786,7 +776,7 @@ func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort ui
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf(invalidIPHeaderLengthMsg)
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen

View File

@@ -538,7 +538,9 @@ func TestSSHPortRedirectionEndToEnd(t *testing.T) {
// Read server response
buf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Logf("failed to set read deadline: %v", err)
}
n, err := conn.Read(buf)
require.NoError(t, err, "Should read server response")