mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 23:26:41 +00:00
Merge branch 'main' into ssh-rewrite
This commit is contained in:
@@ -260,7 +260,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -268,7 +268,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -880,7 +880,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
@@ -913,7 +913,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
|
||||
@@ -376,7 +376,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
@@ -384,7 +384,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -1350,7 +1350,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
@@ -1426,7 +1426,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
|
||||
@@ -22,6 +22,8 @@ type BaseConnTrack struct {
|
||||
PacketsRx atomic.Uint64
|
||||
BytesTx atomic.Uint64
|
||||
BytesRx atomic.Uint64
|
||||
|
||||
DNATOrigPort atomic.Uint32
|
||||
}
|
||||
|
||||
// these small methods will be inlined by the compiler
|
||||
|
||||
@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
|
||||
if exists {
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
|
||||
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConnection safely retrieves a connection state
|
||||
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// Close stops the cleanup routine and releases resources
|
||||
func (t *TCPTracker) Close() {
|
||||
t.tickerCancel()
|
||||
|
||||
@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound UDP connection
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
|
||||
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
|
||||
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
|
||||
if exists {
|
||||
return origPort
|
||||
}
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
|
||||
return 0
|
||||
}
|
||||
|
||||
// TrackInbound records an inbound UDP connection
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
|
||||
}
|
||||
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
if exists {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
return key, true
|
||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||
}
|
||||
|
||||
return key, false
|
||||
return key, 0, false
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
|
||||
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||
if exists {
|
||||
return
|
||||
}
|
||||
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
SourcePort: srcPort,
|
||||
DestPort: dstPort,
|
||||
}
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
|
||||
@@ -116,11 +116,9 @@ type Manager struct {
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
|
||||
// Port-specific DNAT for SSH redirection
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATMap *portDNATMap
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
portNATTracker *portNATTracker
|
||||
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
@@ -137,6 +135,8 @@ type decoder struct {
|
||||
icmp6 layers.ICMPv6
|
||||
decoded []gopacket.LayerType
|
||||
parser *gopacket.DecodingLayerParser
|
||||
|
||||
dnatOrigPort uint16
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
@@ -211,8 +211,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATMap: &portDNATMap{rules: make([]portDNATRule, 0)},
|
||||
portNATTracker: newPortNATTracker(),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
@@ -351,22 +350,18 @@ func (m *Manager) initForwarder() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init initializes the firewall manager with state manager.
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsServerRouteSupported returns whether server routes are supported.
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsStateful returns whether the firewall manager tracks connection state.
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return m.stateful
|
||||
}
|
||||
|
||||
// AddNatRule adds a routing firewall rule for NAT translation.
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
@@ -652,9 +647,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
m.translateOutboundPortReverse(packetData, d)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -697,14 +691,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite UDP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||
if origPort == 0 {
|
||||
break
|
||||
}
|
||||
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite TCP port: %v", err)
|
||||
}
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||
}
|
||||
@@ -714,13 +720,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeUDP:
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeTCP:
|
||||
flags := getTCPFlags(&d.tcp)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||
}
|
||||
|
||||
d.dnatOrigPort = 0
|
||||
}
|
||||
|
||||
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||
@@ -782,10 +790,11 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if translated := m.translateInboundPortDNAT(packetData, d); translated {
|
||||
// TODO: optimize port DNAT by caching matched rules in conntrack
|
||||
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("Failed to re-decode packet after port DNAT: %v", err)
|
||||
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
@@ -794,7 +803,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
|
||||
@@ -5,8 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
"slices"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -21,10 +20,16 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
errRewriteTCPDestinationPort = "rewrite TCP destination port: %v"
|
||||
// Port offsets in TCP/UDP headers
|
||||
sourcePortOffset = 0
|
||||
destinationPortOffset = 2
|
||||
|
||||
// IP address offsets in IPv4 header
|
||||
sourceIPOffset = 12
|
||||
destinationIPOffset = 16
|
||||
)
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum using optimized parallel processing for performance.
|
||||
// ipv4Checksum calculates IPv4 header checksum.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
@@ -64,7 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// icmpChecksum calculates ICMP checksum using parallel accumulation for high-performance processing.
|
||||
// icmpChecksum calculates ICMP checksum.
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
@@ -102,116 +107,21 @@ func icmpChecksum(data []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// biDNATMap maintains bidirectional DNAT mappings for efficient forward and reverse lookups.
|
||||
// biDNATMap maintains bidirectional DNAT mappings.
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
// portDNATRule represents a port-specific DNAT rule
|
||||
// portDNATRule represents a port-specific DNAT rule.
|
||||
type portDNATRule struct {
|
||||
protocol gopacket.LayerType
|
||||
sourcePort uint16
|
||||
origPort uint16
|
||||
targetPort uint16
|
||||
targetIP netip.Addr
|
||||
}
|
||||
|
||||
// portDNATMap manages port-specific DNAT rules
|
||||
type portDNATMap struct {
|
||||
rules []portDNATRule
|
||||
}
|
||||
|
||||
// ConnKey represents a connection 4-tuple for NAT tracking.
|
||||
type ConnKey struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// portNATConn tracks port NAT state for a specific connection.
|
||||
type portNATConn struct {
|
||||
rule portDNATRule
|
||||
originalPort uint16
|
||||
translatedAt time.Time
|
||||
}
|
||||
|
||||
// portNATTracker tracks connection-specific port NAT state
|
||||
type portNATTracker struct {
|
||||
connections map[ConnKey]*portNATConn
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// newPortNATTracker creates a new port NAT tracker for stateful connection tracking.
|
||||
func newPortNATTracker() *portNATTracker {
|
||||
return &portNATTracker{
|
||||
connections: make(map[ConnKey]*portNATConn),
|
||||
}
|
||||
}
|
||||
|
||||
// trackConnection tracks a connection that has port NAT applied using translated port as key.
|
||||
func (t *portNATTracker) trackConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, rule portDNATRule) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: rule.targetPort,
|
||||
}
|
||||
|
||||
t.connections[key] = &portNATConn{
|
||||
rule: rule,
|
||||
originalPort: dstPort,
|
||||
translatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectionNAT returns NAT info for a connection if tracked, looking up by connection 4-tuple.
|
||||
func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) (*portNATConn, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts.
|
||||
func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if key.SrcIP == dstIP && key.DstIP == srcIP &&
|
||||
conn.rule.sourcePort == dstPort && conn.originalPort == dstPort {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupConnection removes a NAT connection based on original 4-tuple for connection cleanup.
|
||||
func (t *portNATTracker) cleanupConnection(srcIP, dstIP netip.Addr, srcPort uint16) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key := range t.connections {
|
||||
if key.SrcIP == srcIP && key.DstIP == dstIP && key.SrcPort == srcPort {
|
||||
delete(t.connections, key)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newBiDNATMap creates a new bidirectional DNAT mapping structure for efficient forward/reverse lookups.
|
||||
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
@@ -219,7 +129,7 @@ func newBiDNATMap() *biDNATMap {
|
||||
}
|
||||
}
|
||||
|
||||
// set adds a bidirectional DNAT mapping between original and translated addresses for both directions.
|
||||
// set adds a bidirectional DNAT mapping between original and translated addresses.
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
@@ -233,13 +143,13 @@ func (b *biDNATMap) delete(original netip.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
// getTranslated returns the translated address for a given original address from forward mapping.
|
||||
// getTranslated returns the translated address for a given original address.
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// getOriginal returns the original address for a given translated address from reverse mapping.
|
||||
// getOriginal returns the original address for a given translated address.
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
@@ -261,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
@@ -295,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists with fast-path optimization.
|
||||
// getDNATTranslation returns the translated address if a mapping exists.
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
@@ -307,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic using reverse mapping.
|
||||
// findReverseDNATMapping finds original address for return traffic.
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
@@ -319,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets for 1:1 IP mapping.
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets.
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||
@@ -336,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error1("rewrite packet destination: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -345,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic for 1:1 IP mapping.
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic.
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||
@@ -362,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error1("rewrite packet source: %v", err)
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -371,17 +272,17 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
|
||||
if !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldDst [4]byte
|
||||
copy(oldDst[:], packetData[16:20])
|
||||
newDst := newIP.As4()
|
||||
var oldIP [4]byte
|
||||
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
|
||||
newIPBytes := newIP.As4()
|
||||
|
||||
copy(packetData[16:20], newDst[:])
|
||||
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
@@ -395,9 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
@@ -406,42 +307,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
var oldSrc [4]byte
|
||||
copy(oldSrc[:], packetData[12:16])
|
||||
newSrc := newIP.As4()
|
||||
|
||||
copy(packetData[12:16], newSrc[:])
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||
|
||||
if len(d.decoded) > 1 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeUDP:
|
||||
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||
case layers.LayerTypeICMPv4:
|
||||
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTCPChecksum updates TCP checksum after IP address change using incremental update per RFC 1624.
|
||||
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
@@ -454,7 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateUDPChecksum updates UDP checksum after IP address change using incremental update per RFC 1624.
|
||||
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
@@ -472,7 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateICMPChecksum recalculates ICMP checksum after packet modification using full recalculation.
|
||||
// updateICMPChecksum recalculates ICMP checksum after packet modification.
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
@@ -485,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624 for performance.
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
@@ -536,25 +402,25 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// addPortRedirection adds port redirection rule for specified target IP, protocol and ports.
|
||||
// addPortRedirection adds a port redirection rule.
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
rule := portDNATRule{
|
||||
protocol: protocol,
|
||||
sourcePort: sourcePort,
|
||||
origPort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
}
|
||||
|
||||
m.portDNATMap.rules = append(m.portDNATMap.rules, rule)
|
||||
m.portDNATRules = append(m.portDNATRules, rule)
|
||||
m.portDNATEnabled.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports.
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -569,27 +435,23 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// removePortRedirection removes port redirection rule for specified target IP, protocol and ports.
|
||||
// removePortRedirection removes a port redirection rule.
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
var filteredRules []portDNATRule
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if !(rule.protocol == protocol && rule.sourcePort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0) {
|
||||
filteredRules = append(filteredRules, rule)
|
||||
}
|
||||
}
|
||||
m.portDNATMap.rules = filteredRules
|
||||
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
|
||||
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
|
||||
})
|
||||
|
||||
if len(m.portDNATMap.rules) == 0 {
|
||||
if len(m.portDNATRules) == 0 {
|
||||
m.portDNATEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports.
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
switch protocol {
|
||||
@@ -604,146 +466,55 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies stateful port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder) bool {
|
||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
|
||||
if m.handleReturnTraffic(packetData, d, srcIP, dstIP, srcPort, dstPort) {
|
||||
return true
|
||||
}
|
||||
|
||||
return m.handleNewConnection(packetData, d, srcIP, dstIP, srcPort, dstPort)
|
||||
}
|
||||
|
||||
// handleReturnTraffic processes return traffic for existing NAT connections.
|
||||
func (m *Manager) handleReturnTraffic(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
if m.isTranslatedPortTraffic(srcIP, srcPort) {
|
||||
return false
|
||||
}
|
||||
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
|
||||
|
||||
if handled := m.handleExistingNATConnection(packetData, d, srcIP, dstIP, srcPort, dstPort); handled {
|
||||
return true
|
||||
}
|
||||
|
||||
return m.handleForwardTrafficInExistingConnections(packetData, d, srcIP, dstIP, srcPort, dstPort)
|
||||
}
|
||||
|
||||
// isTranslatedPortTraffic checks if traffic is from a translated port that should be handled by outbound reverse.
|
||||
func (m *Manager) isTranslatedPortTraffic(srcIP netip.Addr, srcPort uint16) bool {
|
||||
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if rule.protocol == layers.LayerTypeTCP && rule.targetPort == srcPort &&
|
||||
rule.targetIP.Unmap().Compare(srcIP.Unmap()) == 0 {
|
||||
return true
|
||||
for _, rule := range m.portDNATRules {
|
||||
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// handleExistingNATConnection processes return traffic for existing NAT connections.
|
||||
func (m *Manager) handleExistingNATConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
|
||||
if err := m.rewriteTCPDestinationPort(packetData, d, natConn.originalPort); err != nil {
|
||||
m.logger.Error1(errRewriteTCPDestinationPort, err)
|
||||
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
|
||||
return false
|
||||
}
|
||||
m.logger.Trace4("Inbound Port DNAT (return): %s:%d -> %s:%d", dstIP, srcPort, dstIP, natConn.originalPort)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// handleForwardTrafficInExistingConnections processes forward traffic in existing connections.
|
||||
func (m *Manager) handleForwardTrafficInExistingConnections(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
|
||||
continue
|
||||
}
|
||||
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
|
||||
if rule.origPort != port {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := m.portNATTracker.getConnectionNAT(srcIP, dstIP, srcPort, rule.targetPort); !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
|
||||
m.logger.Error1(errRewriteTCPDestinationPort, err)
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// handleNewConnection processes new connections that match port DNAT rules.
|
||||
func (m *Manager) handleNewConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if m.applyPortDNATRule(packetData, d, rule, srcIP, dstIP, srcPort, dstPort) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// applyPortDNATRule applies a specific port DNAT rule if conditions are met.
|
||||
func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNATRule, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.portNATTracker.shouldApplyNAT(srcIP, dstIP, dstPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
|
||||
m.logger.Error1(errRewriteTCPDestinationPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.portNATTracker.trackConnection(srcIP, dstIP, srcPort, dstPort, rule)
|
||||
m.logger.Trace8("Inbound Port DNAT (new): %s:%d -> %s:%d (tracked: %s:%d -> %s:%d)", dstIP, rule.sourcePort, dstIP, rule.targetPort, srcIP, srcPort, dstIP, rule.targetPort)
|
||||
return true
|
||||
}
|
||||
|
||||
// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum.
|
||||
func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return fmt.Errorf("not a TCP packet")
|
||||
}
|
||||
|
||||
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
@@ -754,9 +525,9 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
|
||||
oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4])
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort)
|
||||
portStart := tcpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
@@ -773,75 +544,34 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum.
|
||||
func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return fmt.Errorf("not a TCP packet")
|
||||
}
|
||||
|
||||
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
|
||||
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return errInvalidIPHeaderLength
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
return fmt.Errorf("packet too short for UDP header")
|
||||
}
|
||||
|
||||
oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2])
|
||||
portStart := udpStart + portOffset
|
||||
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
|
||||
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
checksumOffset := udpStart + 6
|
||||
if len(packetData) >= udpStart+8 {
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
if oldChecksum != 0 {
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// translateOutboundPortReverse applies stateful reverse port DNAT to outbound return traffic for SSH redirection.
|
||||
func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
|
||||
// For outbound reverse, we need to find the connection using the same key as when it was stored
|
||||
// Connection was stored as: srcIP, dstIP, srcPort, translatedPort
|
||||
// So for return traffic (srcIP=server, dstIP=client), we need: dstIP, srcIP, dstPort, srcPort
|
||||
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
|
||||
if err := m.rewriteTCPSourcePort(packetData, d, natConn.rule.sourcePort); err != nil {
|
||||
m.logger.Error1("rewrite TCP source port: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkPortDNAT measures the performance of port DNAT operations
|
||||
func BenchmarkPortDNAT(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
proto layers.IPProtocol
|
||||
setupDNAT bool
|
||||
useMatchPort bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "tcp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "TCP inbound port DNAT translation (22 → 22022)",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "tcp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolTCP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "TCP inbound without DNAT (baseline)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_match",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: true,
|
||||
description: "UDP inbound port DNAT translation (5353 → 22054)",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_dnat_nomatch",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: true,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound with DNAT configured but no port match",
|
||||
},
|
||||
{
|
||||
name: "udp_inbound_no_dnat",
|
||||
proto: layers.IPProtocolUDP,
|
||||
setupDNAT: false,
|
||||
useMatchPort: false,
|
||||
description: "UDP inbound without DNAT (baseline)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Set logger to error level to reduce noise during benchmarking
|
||||
manager.SetLogLevel(log.ErrorLevel)
|
||||
defer func() {
|
||||
// Restore to info level after benchmark
|
||||
manager.SetLogLevel(log.InfoLevel)
|
||||
}()
|
||||
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
var origPort, targetPort, testPort uint16
|
||||
if sc.proto == layers.IPProtocolTCP {
|
||||
origPort, targetPort = 22, 22022
|
||||
} else {
|
||||
origPort, targetPort = 5353, 22054
|
||||
}
|
||||
|
||||
if sc.useMatchPort {
|
||||
testPort = origPort
|
||||
} else {
|
||||
testPort = 443 // Different port
|
||||
}
|
||||
|
||||
// Setup port DNAT mapping if needed
|
||||
if sc.setupDNAT {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Pre-establish inbound connection for outbound reverse test
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
|
||||
manager.filterInbound(inboundPacket, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
// Benchmark inbound DNAT translation
|
||||
b.Run("inbound", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh packet each time
|
||||
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
|
||||
manager.filterInbound(packet, 0)
|
||||
}
|
||||
})
|
||||
|
||||
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
|
||||
if sc.setupDNAT && sc.useMatchPort {
|
||||
b.Run("outbound_reverse", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create fresh return packet (from target port)
|
||||
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
|
||||
manager.filterOutbound(packet, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -148,8 +145,7 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirection tests SSH port redirection functionality
|
||||
func TestSSHPortRedirection(t *testing.T) {
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
@@ -158,462 +154,48 @@ func TestSSHPortRedirection(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
// Add SSH port redirection rule
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
sourcePort uint16
|
||||
targetPort uint16
|
||||
}{
|
||||
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
|
||||
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
|
||||
}
|
||||
|
||||
// Verify port DNAT is enabled
|
||||
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
|
||||
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the rule configuration
|
||||
rule := manager.portDNATMap.rules[0]
|
||||
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
|
||||
require.Equal(t, uint16(22), rule.sourcePort)
|
||||
require.Equal(t, uint16(22022), rule.targetPort)
|
||||
require.Equal(t, peerIP, rule.targetIP)
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
|
||||
d := parsePacket(t, inboundPacket)
|
||||
|
||||
// Test inbound SSH packet (client -> peer:22, should redirect to peer:22022)
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
originalInbound := make([]byte, len(inboundPacket))
|
||||
copy(originalInbound, inboundPacket)
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
|
||||
require.True(t, translated, "Inbound packet should be translated")
|
||||
|
||||
// Process inbound packet
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, translated, "Inbound SSH packet should be translated")
|
||||
|
||||
// Verify destination port was changed from 22 to 22022
|
||||
d := parsePacket(t, inboundPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Destination port should be rewritten to 22022")
|
||||
|
||||
// Verify destination IP remains unchanged
|
||||
dstIPAfter := netip.AddrFrom4([4]byte{inboundPacket[16], inboundPacket[17], inboundPacket[18], inboundPacket[19]})
|
||||
require.Equal(t, peerIP, dstIPAfter, "Destination IP should remain unchanged")
|
||||
|
||||
// Test outbound return packet (peer:22022 -> client, should rewrite source port to 22)
|
||||
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
|
||||
originalOutbound := make([]byte, len(outboundPacket))
|
||||
copy(originalOutbound, outboundPacket)
|
||||
|
||||
// Process outbound return packet
|
||||
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, reversed, "Outbound return packet should be reverse translated")
|
||||
|
||||
// Verify source port was changed from 22022 to 22
|
||||
d = parsePacket(t, outboundPacket)
|
||||
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Source port should be rewritten to 22")
|
||||
|
||||
// Verify source IP remains unchanged
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
|
||||
require.Equal(t, peerIP, srcIPAfter, "Source IP should remain unchanged")
|
||||
|
||||
// Test removal of SSH port redirection
|
||||
err = manager.RemoveInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
|
||||
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionNetworkFiltering tests that SSH redirection only applies to specified networks
|
||||
func TestSSHPortRedirectionNetworkFiltering(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerInNetwork := netip.MustParseAddr("100.10.0.50")
|
||||
peerOutsideNetwork := netip.MustParseAddr("192.168.1.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule for NetBird network only
|
||||
err = manager.AddInboundDNAT(peerInNetwork, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test SSH packet to peer within NetBird network (should be redirected)
|
||||
inNetworkPacket := generateDNATTestPacket(t, clientIP, peerInNetwork, layers.IPProtocolTCP, 54321, 22)
|
||||
translated := manager.translateInboundPortDNAT(inNetworkPacket, parsePacket(t, inNetworkPacket))
|
||||
require.True(t, translated, "SSH packet to NetBird peer should be translated")
|
||||
|
||||
// Verify port was changed
|
||||
d := parsePacket(t, inNetworkPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected for NetBird peer")
|
||||
|
||||
// Test SSH packet to peer outside NetBird network (should NOT be redirected)
|
||||
outOfNetworkPacket := generateDNATTestPacket(t, clientIP, peerOutsideNetwork, layers.IPProtocolTCP, 54321, 22)
|
||||
originalOutOfNetwork := make([]byte, len(outOfNetworkPacket))
|
||||
copy(originalOutOfNetwork, outOfNetworkPacket)
|
||||
|
||||
notTranslated := manager.translateInboundPortDNAT(outOfNetworkPacket, parsePacket(t, outOfNetworkPacket))
|
||||
require.False(t, notTranslated, "SSH packet to non-NetBird peer should NOT be translated")
|
||||
|
||||
// Verify packet was not modified
|
||||
require.Equal(t, originalOutOfNetwork, outOfNetworkPacket, "Packet to non-NetBird peer should remain unchanged")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionNonTCPTraffic tests that only TCP traffic is affected
|
||||
func TestSSHPortRedirectionNonTCPTraffic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test UDP packet on port 22 (should NOT be redirected)
|
||||
udpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolUDP, 54321, 22)
|
||||
originalUDP := make([]byte, len(udpPacket))
|
||||
copy(originalUDP, udpPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(udpPacket, parsePacket(t, udpPacket))
|
||||
require.False(t, translated, "UDP packet should NOT be translated by SSH port redirection")
|
||||
require.Equal(t, originalUDP, udpPacket, "UDP packet should remain unchanged")
|
||||
|
||||
// Test ICMP packet (should NOT be redirected)
|
||||
icmpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolICMPv4, 0, 0)
|
||||
originalICMP := make([]byte, len(icmpPacket))
|
||||
copy(originalICMP, icmpPacket)
|
||||
|
||||
translated = manager.translateInboundPortDNAT(icmpPacket, parsePacket(t, icmpPacket))
|
||||
require.False(t, translated, "ICMP packet should NOT be translated by SSH port redirection")
|
||||
require.Equal(t, originalICMP, icmpPacket, "ICMP packet should remain unchanged")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionNonSSHPorts tests that only port 22 is redirected
|
||||
func TestSSHPortRedirectionNonSSHPorts(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test TCP packet on port 80 (should NOT be redirected)
|
||||
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
|
||||
originalHTTP := make([]byte, len(httpPacket))
|
||||
copy(originalHTTP, httpPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
|
||||
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
|
||||
require.Equal(t, originalHTTP, httpPacket, "Non-SSH TCP packet should remain unchanged")
|
||||
|
||||
// Test TCP packet on port 443 (should NOT be redirected)
|
||||
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
|
||||
originalHTTPS := make([]byte, len(httpsPacket))
|
||||
copy(originalHTTPS, httpsPacket)
|
||||
|
||||
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
|
||||
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
|
||||
require.Equal(t, originalHTTPS, httpsPacket, "Non-SSH TCP packet should remain unchanged")
|
||||
|
||||
// Test TCP packet on port 22 (SHOULD be redirected)
|
||||
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
translated = manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
|
||||
require.True(t, translated, "SSH TCP packet should be translated")
|
||||
|
||||
// Verify port was changed to 22022
|
||||
d := parsePacket(t, sshPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH port should be redirected to 22022")
|
||||
}
|
||||
|
||||
// TestFlexiblePortRedirection tests the flexible port redirection functionality
|
||||
func TestFlexiblePortRedirection(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer and client IPs
|
||||
peerIP := netip.MustParseAddr("10.0.0.50")
|
||||
clientIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
// Add custom port redirection: TCP port 8080 -> 3000 for peer IP
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify port DNAT is enabled
|
||||
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
|
||||
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
|
||||
|
||||
// Verify the rule configuration
|
||||
rule := manager.portDNATMap.rules[0]
|
||||
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
|
||||
require.Equal(t, uint16(8080), rule.sourcePort)
|
||||
require.Equal(t, uint16(3000), rule.targetPort)
|
||||
require.Equal(t, peerIP, rule.targetIP)
|
||||
|
||||
// Test inbound packet (client -> peer:8080, should redirect to peer:3000)
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 8080)
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, translated, "Inbound packet should be translated")
|
||||
|
||||
// Verify destination port was changed from 8080 to 3000
|
||||
d := parsePacket(t, inboundPacket)
|
||||
require.Equal(t, uint16(3000), uint16(d.tcp.DstPort), "Destination port should be rewritten to 3000")
|
||||
|
||||
// Test outbound return packet (peer:3000 -> client, should rewrite source port to 8080)
|
||||
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 3000, 54321)
|
||||
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, reversed, "Outbound return packet should be reverse translated")
|
||||
|
||||
// Verify source port was changed from 3000 to 8080
|
||||
d = parsePacket(t, outboundPacket)
|
||||
require.Equal(t, uint16(8080), uint16(d.tcp.SrcPort), "Source port should be rewritten to 8080")
|
||||
|
||||
// Test removal of port redirection
|
||||
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
|
||||
require.NoError(t, err)
|
||||
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
|
||||
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
|
||||
}
|
||||
|
||||
// TestMultiplePortRedirections tests multiple port redirection rules
|
||||
func TestMultiplePortRedirections(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer and client IPs
|
||||
peerIP := netip.MustParseAddr("172.16.0.50")
|
||||
clientIP := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
// Add multiple port redirections for peer IP
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 22, 22022) // SSH
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) // HTTP
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 443, 8443) // HTTPS
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all rules are present
|
||||
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
|
||||
require.Len(t, manager.portDNATMap.rules, 3, "Should have three port DNAT rules")
|
||||
|
||||
// Test SSH redirection (22 -> 22022)
|
||||
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
|
||||
require.True(t, translated, "SSH packet should be translated")
|
||||
d := parsePacket(t, sshPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH should redirect to 22022")
|
||||
|
||||
// Test HTTP redirection (80 -> 8080)
|
||||
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
|
||||
translated = manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
|
||||
require.True(t, translated, "HTTP packet should be translated")
|
||||
d = parsePacket(t, httpPacket)
|
||||
require.Equal(t, uint16(8080), uint16(d.tcp.DstPort), "HTTP should redirect to 8080")
|
||||
|
||||
// Test HTTPS redirection (443 -> 8443)
|
||||
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
|
||||
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
|
||||
require.True(t, translated, "HTTPS packet should be translated")
|
||||
d = parsePacket(t, httpsPacket)
|
||||
require.Equal(t, uint16(8443), uint16(d.tcp.DstPort), "HTTPS should redirect to 8443")
|
||||
|
||||
// Test removing one rule (HTTP)
|
||||
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, manager.portDNATMap.rules, 2, "Should have two rules after removing HTTP rule")
|
||||
|
||||
// Verify HTTP is no longer redirected
|
||||
httpPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
|
||||
originalHTTP := make([]byte, len(httpPacket2))
|
||||
copy(originalHTTP, httpPacket2)
|
||||
translated = manager.translateInboundPortDNAT(httpPacket2, parsePacket(t, httpPacket2))
|
||||
require.False(t, translated, "HTTP packet should NOT be translated after rule removal")
|
||||
require.Equal(t, originalHTTP, httpPacket2, "HTTP packet should remain unchanged")
|
||||
|
||||
// Verify SSH and HTTPS still work
|
||||
sshPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
translated = manager.translateInboundPortDNAT(sshPacket2, parsePacket(t, sshPacket2))
|
||||
require.True(t, translated, "SSH should still be translated")
|
||||
|
||||
httpsPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
|
||||
translated = manager.translateInboundPortDNAT(httpsPacket2, parsePacket(t, httpsPacket2))
|
||||
require.True(t, translated, "HTTPS should still be translated")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionEndToEnd tests actual network delivery through sockets
|
||||
func TestSSHPortRedirectionEndToEnd(t *testing.T) {
|
||||
// Start a mock SSH server on port 22022 (NetBird SSH server)
|
||||
mockSSHServer, err := net.Listen("tcp", "127.0.0.1:22022")
|
||||
require.NoError(t, err, "Should be able to bind to NetBird SSH port")
|
||||
defer func() {
|
||||
require.NoError(t, mockSSHServer.Close())
|
||||
}()
|
||||
|
||||
// Handle connections on the SSH server
|
||||
serverReceivedData := make(chan string, 1)
|
||||
go func() {
|
||||
for {
|
||||
conn, err := mockSSHServer.Accept()
|
||||
if err != nil {
|
||||
return // Server closed
|
||||
d = parsePacket(t, inboundPacket)
|
||||
var dstPort uint16
|
||||
switch tc.protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.IPProtocolUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Logf("Server read error: %v", err)
|
||||
return
|
||||
}
|
||||
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
|
||||
|
||||
receivedData := string(buf[:n])
|
||||
serverReceivedData <- receivedData
|
||||
|
||||
// Echo back a response
|
||||
_, err = conn.Write([]byte("SSH-2.0-MockNetBirdSSH\r\n"))
|
||||
if err != nil {
|
||||
t.Logf("Server write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// This test demonstrates what SHOULD happen after port redirection:
|
||||
// 1. Client connects to 127.0.0.1:22 (standard SSH port)
|
||||
// 2. Firewall redirects to 127.0.0.1:22022 (NetBird SSH server)
|
||||
// 3. NetBird SSH server receives the connection
|
||||
|
||||
t.Run("DirectConnectionToNetBirdSSHPort", func(t *testing.T) {
|
||||
// This simulates what should happen AFTER port redirection
|
||||
// Connect directly to 22022 (where NetBird SSH server listens)
|
||||
conn, err := net.DialTimeout("tcp", "127.0.0.1:22022", 5*time.Second)
|
||||
require.NoError(t, err, "Should connect to NetBird SSH server")
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
// Send SSH client identification
|
||||
testData := "SSH-2.0-TestClient\r\n"
|
||||
_, err = conn.Write([]byte(testData))
|
||||
require.NoError(t, err, "Should send data to SSH server")
|
||||
|
||||
// Verify server received the data
|
||||
select {
|
||||
case received := <-serverReceivedData:
|
||||
require.Equal(t, testData, received, "Server should receive client data")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Server did not receive data within timeout")
|
||||
}
|
||||
|
||||
// Read server response
|
||||
buf := make([]byte, 1024)
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Logf("failed to set read deadline: %v", err)
|
||||
}
|
||||
n, err := conn.Read(buf)
|
||||
require.NoError(t, err, "Should read server response")
|
||||
|
||||
response := string(buf[:n])
|
||||
require.Equal(t, "SSH-2.0-MockNetBirdSSH\r\n", response, "Should receive SSH server identification")
|
||||
})
|
||||
|
||||
t.Run("PortRedirectionSimulation", func(t *testing.T) {
|
||||
// This test simulates the port redirection process
|
||||
// Note: This doesn't test the actual userspace packet interception,
|
||||
// but demonstrates the expected behavior
|
||||
|
||||
t.Log("NOTE: This test demonstrates expected behavior after implementing")
|
||||
t.Log("full userspace packet interception. Currently, we test packet")
|
||||
t.Log("translation logic separately from actual network delivery.")
|
||||
|
||||
// In a real implementation with userspace packet interception:
|
||||
// 1. Client would connect to 127.0.0.1:22
|
||||
// 2. Userspace firewall would intercept packets
|
||||
// 3. translateInboundPortDNAT would rewrite port 22 -> 22022
|
||||
// 4. Packets would be delivered to 127.0.0.1:22022
|
||||
// 5. NetBird SSH server would receive the connection
|
||||
|
||||
// For now, we verify that the packet translation logic works correctly
|
||||
// (this is tested in other test functions) and that the target server
|
||||
// is reachable (tested above)
|
||||
|
||||
clientIP := netip.MustParseAddr("127.0.0.1")
|
||||
serverIP := netip.MustParseAddr("127.0.0.1")
|
||||
|
||||
// Create manager with SSH port redirection
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add SSH port redirection for localhost (for testing)
|
||||
err = manager.AddInboundDNAT(netip.MustParseAddr("127.0.0.1"), firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate packet: client connecting to server:22
|
||||
sshPacket := generateDNATTestPacket(t, clientIP, serverIP, layers.IPProtocolTCP, 54321, 22)
|
||||
originalPacket := make([]byte, len(sshPacket))
|
||||
copy(originalPacket, sshPacket)
|
||||
|
||||
// Apply port redirection
|
||||
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
|
||||
require.True(t, translated, "SSH packet should be translated")
|
||||
|
||||
// Verify port was redirected from 22 to 22022
|
||||
d := parsePacket(t, sshPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected to NetBird SSH server")
|
||||
require.NotEqual(t, originalPacket, sshPacket, "Packet should be modified")
|
||||
|
||||
t.Log("✓ Packet translation verified: port 22 redirected to 22022")
|
||||
t.Log("✓ Target SSH server (port 22022) is reachable and responsive")
|
||||
t.Log("→ Integration complete: SSH port redirection ready for userspace interception")
|
||||
})
|
||||
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFullSSHRedirectionWorkflow demonstrates the complete SSH redirection workflow
|
||||
func TestFullSSHRedirectionWorkflow(t *testing.T) {
|
||||
t.Log("=== SSH Port Redirection Workflow Test ===")
|
||||
t.Log("This test demonstrates the complete SSH redirection process:")
|
||||
t.Log("1. Client connects to peer:22 (standard SSH)")
|
||||
t.Log("2. Userspace firewall intercepts and redirects to peer:22022")
|
||||
t.Log("3. NetBird SSH server receives connection on port 22022")
|
||||
t.Log("4. Return traffic is reverse-translated (22022 -> 22)")
|
||||
|
||||
// Setup test environment
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
@@ -622,47 +204,51 @@ func TestFullSSHRedirectionWorkflow(t *testing.T) {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network and peer IPs
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
localAddr := netip.MustParseAddr("100.0.2.175")
|
||||
clientIP := netip.MustParseAddr("100.0.169.249")
|
||||
|
||||
// Step 1: Configure SSH port redirection
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
t.Log("✓ SSH port redirection configured for NetBird network")
|
||||
|
||||
// Step 2: Simulate inbound SSH connection (client -> peer:22)
|
||||
t.Log("→ Simulating: ssh user@100.10.0.50")
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
testCases := []struct {
|
||||
name string
|
||||
protocol layers.IPProtocol
|
||||
srcIP netip.Addr
|
||||
dstIP netip.Addr
|
||||
srcPort uint16
|
||||
dstPort uint16
|
||||
}{
|
||||
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
|
||||
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
|
||||
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
|
||||
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
|
||||
}
|
||||
|
||||
// Step 3: Apply inbound port redirection
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, translated, "Inbound SSH packet should be redirected")
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||
d := parsePacket(t, packet)
|
||||
|
||||
d := parsePacket(t, inboundPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Should redirect to NetBird SSH server port")
|
||||
t.Log("✓ Inbound packet redirected: 100.10.0.50:22 → 100.10.0.50:22022")
|
||||
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
|
||||
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
|
||||
|
||||
// Step 4: Simulate outbound return traffic (peer:22022 -> client)
|
||||
t.Log("→ Simulating return traffic from NetBird SSH server")
|
||||
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
|
||||
|
||||
// Step 5: Apply outbound reverse translation
|
||||
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, reversed, "Outbound return packet should be reverse translated")
|
||||
|
||||
d = parsePacket(t, outboundPacket)
|
||||
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Should restore original SSH port")
|
||||
t.Log("✓ Outbound packet reverse translated: 100.10.0.50:22022 → 100.10.0.50:22")
|
||||
|
||||
// Step 6: Verify client sees standard SSH connection
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
|
||||
require.Equal(t, peerIP, srcIPAfter, "Client should see traffic from peer IP")
|
||||
t.Log("✓ Client receives traffic from 100.10.0.50:22 (transparent redirection)")
|
||||
|
||||
t.Log("=== SSH Port Redirection Workflow Complete ===")
|
||||
t.Log("Result: Standard SSH clients can connect to NetBird peers using:")
|
||||
t.Log(" ssh user@100.10.0.50")
|
||||
t.Log("Instead of:")
|
||||
t.Log(" ssh user@100.10.0.50 -p 22022")
|
||||
d = parsePacket(t, packet)
|
||||
if tc.protocol == layers.IPProtocolTCP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
|
||||
} else if tc.protocol == layers.IPProtocolUDP {
|
||||
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
|
||||
switch proto {
|
||||
case layers.IPProtocolTCP:
|
||||
return firewall.ProtocolTCP
|
||||
case layers.IPProtocolUDP:
|
||||
return firewall.ProtocolUDP
|
||||
default:
|
||||
return firewall.ProtocolALL
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,25 +16,33 @@ type PacketStage int
|
||||
|
||||
const (
|
||||
StageReceived PacketStage = iota
|
||||
StageInboundPortDNAT
|
||||
StageInbound1to1NAT
|
||||
StageConntrack
|
||||
StagePeerACL
|
||||
StageRouting
|
||||
StageRouteACL
|
||||
StageForwarding
|
||||
StageCompleted
|
||||
StageOutbound1to1NAT
|
||||
StageOutboundPortReverse
|
||||
)
|
||||
|
||||
const msgProcessingCompleted = "Processing completed"
|
||||
|
||||
func (s PacketStage) String() string {
|
||||
return map[PacketStage]string{
|
||||
StageReceived: "Received",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageReceived: "Received",
|
||||
StageInboundPortDNAT: "Inbound Port DNAT",
|
||||
StageInbound1to1NAT: "Inbound 1:1 NAT",
|
||||
StageConntrack: "Connection Tracking",
|
||||
StagePeerACL: "Peer ACL",
|
||||
StageRouting: "Routing",
|
||||
StageRouteACL: "Route ACL",
|
||||
StageForwarding: "Forwarding",
|
||||
StageCompleted: "Completed",
|
||||
StageOutbound1to1NAT: "Outbound 1:1 NAT",
|
||||
StageOutboundPortReverse: "Outbound DNAT Reverse",
|
||||
}[s]
|
||||
}
|
||||
|
||||
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
|
||||
return trace
|
||||
}
|
||||
|
||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||
return trace
|
||||
}
|
||||
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||
// will create or update the connection state
|
||||
d := m.decoders.Get().(*decoder)
|
||||
defer m.decoders.Put(d)
|
||||
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
|
||||
return trace
|
||||
}
|
||||
|
||||
m.handleOutboundDNAT(trace, packetData, d)
|
||||
|
||||
dropped := m.filterOutbound(packetData, 0)
|
||||
if dropped {
|
||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
|
||||
}
|
||||
return trace
|
||||
}
|
||||
|
||||
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
|
||||
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
|
||||
if portDNATApplied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
trace.DestinationPort = m.getDestPort(d)
|
||||
}
|
||||
|
||||
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
|
||||
if nat1to1Applied {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
|
||||
return true
|
||||
}
|
||||
*srcIP, *dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
protocol := d.decoded[1]
|
||||
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
|
||||
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
var originalPort uint16
|
||||
if protocol == layers.LayerTypeTCP {
|
||||
originalPort = uint16(d.tcp.DstPort)
|
||||
} else {
|
||||
originalPort = uint16(d.udp.DstPort)
|
||||
}
|
||||
|
||||
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
|
||||
if translated {
|
||||
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
|
||||
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
|
||||
|
||||
protoStr := "TCP"
|
||||
if protocol == layers.LayerTypeUDP {
|
||||
protoStr = "UDP"
|
||||
}
|
||||
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
|
||||
trace.AddResult(StageInboundPortDNAT, msg, true)
|
||||
return true
|
||||
}
|
||||
|
||||
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
|
||||
translated := m.translateInboundReverse(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
|
||||
trace.AddResult(StageInbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
|
||||
m.traceOutbound1to1NAT(trace, packetData, d)
|
||||
m.traceOutboundPortReverse(trace, packetData, d)
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
translated := m.translateOutboundDNAT(packetData, d)
|
||||
if translated {
|
||||
m.dnatMutex.RLock()
|
||||
translatedIP, exists := m.dnatMappings[dstIP]
|
||||
m.dnatMutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
|
||||
trace.AddResult(StageOutbound1to1NAT, msg, true)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
|
||||
var origPort uint16
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
case layers.LayerTypeTCP:
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
srcPort := uint16(d.udp.SrcPort)
|
||||
dstPort := uint16(d.udp.DstPort)
|
||||
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
|
||||
if exists {
|
||||
origPort = uint16(conn.DNATOrigPort.Load())
|
||||
}
|
||||
if origPort != 0 {
|
||||
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
|
||||
trace.AddResult(StageOutboundPortReverse, msg, true)
|
||||
return true
|
||||
}
|
||||
default:
|
||||
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) getDestPort(d *decoder) uint16 {
|
||||
if len(d.decoded) < 2 {
|
||||
return 0
|
||||
}
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
return uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
return uint16(d.udp.DstPort)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageRouteACL,
|
||||
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StageCompleted,
|
||||
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageCompleted,
|
||||
},
|
||||
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageOutbound1to1NAT,
|
||||
StageOutboundPortReverse,
|
||||
StageCompleted,
|
||||
},
|
||||
expectedAllow: true,
|
||||
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageConntrack,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
expectedStages: []PacketStage{
|
||||
StageReceived,
|
||||
StageInboundPortDNAT,
|
||||
StageInbound1to1NAT,
|
||||
StageRouting,
|
||||
StagePeerACL,
|
||||
StageCompleted,
|
||||
|
||||
Reference in New Issue
Block a user