Merge branch 'main' into ssh-rewrite

This commit is contained in:
Viktor Liu
2025-10-28 16:50:23 +01:00
67 changed files with 1606 additions and 1208 deletions

View File

@@ -260,7 +260,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -268,7 +268,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes inbound DNAT rule
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@@ -880,7 +880,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
@@ -913,7 +913,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
return nil
}
// RemoveInboundDNAT removes inbound DNAT rule
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)

View File

@@ -376,7 +376,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -384,7 +384,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes inbound DNAT rule
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@@ -1350,7 +1350,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
@@ -1426,7 +1426,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol
return nil
}
// RemoveInboundDNAT removes inbound DNAT rule
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)

View File

@@ -22,6 +22,8 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
}
// these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists {
t.updateState(key, conn, flags, direction, size)
return key, true
return key, uint16(conn.DNATOrigPort.Load()), true
}
return key, false
return key, 0, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 {
return
}
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
t.logger.Trace2("New %s TCP connection: %s", direction, key)
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
}
}
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
key := ConnKey{
SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
_, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
if exists {
return origPort
}
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
return key, uint16(conn.DNATOrigPort.Load()), true
}
return key, false
return key, 0, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort,
DestPort: dstPort,
}
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace2("New %s UDP connection: %s", direction, key)
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}

View File

@@ -116,11 +116,9 @@ type Manager struct {
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
// Port-specific DNAT for SSH redirection
portDNATEnabled atomic.Bool
portDNATMap *portDNATMap
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
portNATTracker *portNATTracker
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
@@ -137,6 +135,8 @@ type decoder struct {
icmp6 layers.ICMPv6
decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
}
// Create userspace firewall manager constructor
@@ -211,8 +211,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATMap: &portDNATMap{rules: make([]portDNATRule, 0)},
portNATTracker: newPortNATTracker(),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
}
m.routingEnabled.Store(false)
@@ -351,22 +350,18 @@ func (m *Manager) initForwarder() error {
return nil
}
// Init initializes the firewall manager with state manager.
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
// IsServerRouteSupported returns whether server routes are supported.
func (m *Manager) IsServerRouteSupported() bool {
return true
}
// IsStateful returns whether the firewall manager tracks connection state.
func (m *Manager) IsStateful() bool {
return m.stateful
}
// AddNatRule adds a routing firewall rule for NAT translation.
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
@@ -652,9 +647,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
m.trackOutbound(d, srcIP, dstIP, size)
m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.translateOutboundDNAT(packetData, d)
m.translateOutboundPortReverse(packetData, d)
return false
}
@@ -697,14 +691,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
}
@@ -714,13 +720,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1]
switch transport {
case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
}
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
@@ -782,10 +790,11 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
if translated := m.translateInboundPortDNAT(packetData, d); translated {
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after port DNAT: %v", err)
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
@@ -794,7 +803,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err)
m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)

View File

@@ -5,8 +5,7 @@ import (
"errors"
"fmt"
"net/netip"
"sync"
"time"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -21,10 +20,16 @@ var (
)
const (
errRewriteTCPDestinationPort = "rewrite TCP destination port: %v"
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum using optimized parallel processing for performance.
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
@@ -64,7 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum)
}
// icmpChecksum calculates ICMP checksum using parallel accumulation for high-performance processing.
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
@@ -102,116 +107,21 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum)
}
// biDNATMap maintains bidirectional DNAT mappings for efficient forward and reverse lookups.
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
// portDNATRule represents a port-specific DNAT rule
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
sourcePort uint16
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// portDNATMap manages port-specific DNAT rules
type portDNATMap struct {
rules []portDNATRule
}
// ConnKey represents a connection 4-tuple for NAT tracking.
type ConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcPort uint16
DstPort uint16
}
// portNATConn tracks port NAT state for a specific connection.
type portNATConn struct {
rule portDNATRule
originalPort uint16
translatedAt time.Time
}
// portNATTracker tracks connection-specific port NAT state
type portNATTracker struct {
connections map[ConnKey]*portNATConn
mutex sync.RWMutex
}
// newPortNATTracker creates a new port NAT tracker for stateful connection tracking.
func newPortNATTracker() *portNATTracker {
return &portNATTracker{
connections: make(map[ConnKey]*portNATConn),
}
}
// trackConnection tracks a connection that has port NAT applied using translated port as key.
func (t *portNATTracker) trackConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, rule portDNATRule) {
t.mutex.Lock()
defer t.mutex.Unlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: rule.targetPort,
}
t.connections[key] = &portNATConn{
rule: rule,
originalPort: dstPort,
translatedAt: time.Now(),
}
}
// getConnectionNAT returns NAT info for a connection if tracked, looking up by connection 4-tuple.
func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) (*portNATConn, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts.
func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
for key, conn := range t.connections {
if key.SrcIP == dstIP && key.DstIP == srcIP &&
conn.rule.sourcePort == dstPort && conn.originalPort == dstPort {
return false
}
}
return true
}
// cleanupConnection removes a NAT connection based on original 4-tuple for connection cleanup.
func (t *portNATTracker) cleanupConnection(srcIP, dstIP netip.Addr, srcPort uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()
for key := range t.connections {
if key.SrcIP == srcIP && key.DstIP == dstIP && key.SrcPort == srcPort {
delete(t.connections, key)
return
}
}
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure for efficient forward/reverse lookups.
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
@@ -219,7 +129,7 @@ func newBiDNATMap() *biDNATMap {
}
}
// set adds a bidirectional DNAT mapping between original and translated addresses for both directions.
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
@@ -233,13 +143,13 @@ func (b *biDNATMap) delete(original netip.Addr) {
}
}
// getTranslated returns the translated address for a given original address from forward mapping.
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
// getOriginal returns the original address for a given translated address from reverse mapping.
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
@@ -261,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
@@ -295,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil
}
// getDNATTranslation returns the translated address if a mapping exists with fast-path optimization.
// getDNATTranslation returns the translated address if a mapping exists.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
@@ -307,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic using reverse mapping.
// findReverseDNATMapping finds original address for return traffic.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
@@ -319,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets for 1:1 IP mapping.
// translateOutboundDNAT applies DNAT translation to outbound packets.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -336,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error1("rewrite packet destination: %v", err)
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
return false
}
@@ -345,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic for 1:1 IP mapping.
// translateInboundReverse applies reverse DNAT to inbound return traffic.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -362,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error1("rewrite packet source: %v", err)
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
return false
}
@@ -371,17 +272,17 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// rewritePacketDestination replaces destination IP in the packet and updates checksums.
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
var oldIP [4]byte
copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newIPBytes := newIP.As4()
copy(packetData[16:20], newDst[:])
copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
@@ -395,9 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
@@ -406,42 +307,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
return nil
}
// rewritePacketSource replaces the source IP address in the packet and updates checksums.
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// updateTCPChecksum updates TCP checksum after IP address change using incremental update per RFC 1624.
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
@@ -454,7 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateUDPChecksum updates UDP checksum after IP address change using incremental update per RFC 1624.
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
@@ -472,7 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateICMPChecksum recalculates ICMP checksum after packet modification using full recalculation.
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
@@ -485,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624 for performance.
// incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
@@ -536,25 +402,25 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteDNATRule(rule)
}
// addPortRedirection adds port redirection rule for specified target IP, protocol and ports.
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
sourcePort: sourcePort,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATMap.rules = append(m.portDNATMap.rules, rule)
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports.
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -569,27 +435,23 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes port redirection rule for specified target IP, protocol and ports.
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
var filteredRules []portDNATRule
for _, rule := range m.portDNATMap.rules {
if !(rule.protocol == protocol && rule.sourcePort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0) {
filteredRules = append(filteredRules, rule)
}
}
m.portDNATMap.rules = filteredRules
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATMap.rules) == 0 {
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports.
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -604,146 +466,55 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies stateful port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder) bool {
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
if m.handleReturnTraffic(packetData, d, srcIP, dstIP, srcPort, dstPort) {
return true
}
return m.handleNewConnection(packetData, d, srcIP, dstIP, srcPort, dstPort)
}
// handleReturnTraffic processes return traffic for existing NAT connections.
func (m *Manager) handleReturnTraffic(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
if m.isTranslatedPortTraffic(srcIP, srcPort) {
return false
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
if handled := m.handleExistingNATConnection(packetData, d, srcIP, dstIP, srcPort, dstPort); handled {
return true
}
return m.handleForwardTrafficInExistingConnections(packetData, d, srcIP, dstIP, srcPort, dstPort)
}
// isTranslatedPortTraffic checks if traffic is from a translated port that should be handled by outbound reverse.
func (m *Manager) isTranslatedPortTraffic(srcIP netip.Addr, srcPort uint16) bool {
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATMap.rules {
if rule.protocol == layers.LayerTypeTCP && rule.targetPort == srcPort &&
rule.targetIP.Unmap().Compare(srcIP.Unmap()) == 0 {
return true
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
}
return false
}
// handleExistingNATConnection processes return traffic for existing NAT connections.
func (m *Manager) handleExistingNATConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
if err := m.rewriteTCPDestinationPort(packetData, d, natConn.originalPort); err != nil {
m.logger.Error1(errRewriteTCPDestinationPort, err)
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
m.logger.Trace4("Inbound Port DNAT (return): %s:%d -> %s:%d", dstIP, srcPort, dstIP, natConn.originalPort)
return true
}
return false
}
// handleForwardTrafficInExistingConnections processes forward traffic in existing connections.
func (m *Manager) handleForwardTrafficInExistingConnections(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATMap.rules {
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
continue
}
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
if rule.origPort != port {
continue
}
if _, exists := m.portNATTracker.getConnectionNAT(srcIP, dstIP, srcPort, rule.targetPort); !exists {
continue
}
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
m.logger.Error1(errRewriteTCPDestinationPort, err)
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// handleNewConnection processes new connections that match port DNAT rules.
func (m *Manager) handleNewConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATMap.rules {
if m.applyPortDNATRule(packetData, d, rule, srcIP, dstIP, srcPort, dstPort) {
return true
}
}
return false
}
// applyPortDNATRule applies a specific port DNAT rule if conditions are met.
func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNATRule, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
return false
}
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
return false
}
if !m.portNATTracker.shouldApplyNAT(srcIP, dstIP, dstPort) {
return false
}
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
m.logger.Error1(errRewriteTCPDestinationPort, err)
return false
}
m.portNATTracker.trackConnection(srcIP, dstIP, srcPort, dstPort, rule)
m.logger.Trace8("Inbound Port DNAT (new): %s:%d -> %s:%d (tracked: %s:%d -> %s:%d)", dstIP, rule.sourcePort, dstIP, rule.targetPort, srcIP, srcPort, dstIP, rule.targetPort)
return true
}
// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum.
func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return ErrIPv4Only
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return fmt.Errorf("not a TCP packet")
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
@@ -754,9 +525,9 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
return fmt.Errorf("packet too short for TCP header")
}
oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4])
binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort)
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
@@ -773,75 +544,34 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo
return nil
}
// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum.
func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return ErrIPv4Only
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return fmt.Errorf("not a TCP packet")
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2])
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}
// translateOutboundPortReverse applies stateful reverse port DNAT to outbound return traffic for SSH redirection.
func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
// For outbound reverse, we need to find the connection using the same key as when it was stored
// Connection was stored as: srcIP, dstIP, srcPort, translatedPort
// So for return traffic (srcIP=server, dstIP=client), we need: dstIP, srcIP, dstPort, srcPort
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
if err := m.rewriteTCPSourcePort(packetData, d, natConn.rule.sourcePort); err != nil {
m.logger.Error1("rewrite TCP source port: %v", err)
return false
}
return true
}
return false
}

View File

@@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
}
})
}
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -1,11 +1,8 @@
package uspfilter
import (
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -148,8 +145,7 @@ func TestDNATMappingManagement(t *testing.T) {
require.Error(t, err, "Should error when removing non-existent mapping")
}
// TestSSHPortRedirection tests SSH port redirection functionality
func TestSSHPortRedirection(t *testing.T) {
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
@@ -158,462 +154,48 @@ func TestSSHPortRedirection(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
// Add SSH port redirection rule
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
// Verify port DNAT is enabled
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
// Verify the rule configuration
rule := manager.portDNATMap.rules[0]
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
require.Equal(t, uint16(22), rule.sourcePort)
require.Equal(t, uint16(22022), rule.targetPort)
require.Equal(t, peerIP, rule.targetIP)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
// Test inbound SSH packet (client -> peer:22, should redirect to peer:22022)
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
// Process inbound packet
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, translated, "Inbound SSH packet should be translated")
// Verify destination port was changed from 22 to 22022
d := parsePacket(t, inboundPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Destination port should be rewritten to 22022")
// Verify destination IP remains unchanged
dstIPAfter := netip.AddrFrom4([4]byte{inboundPacket[16], inboundPacket[17], inboundPacket[18], inboundPacket[19]})
require.Equal(t, peerIP, dstIPAfter, "Destination IP should remain unchanged")
// Test outbound return packet (peer:22022 -> client, should rewrite source port to 22)
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound return packet
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, reversed, "Outbound return packet should be reverse translated")
// Verify source port was changed from 22022 to 22
d = parsePacket(t, outboundPacket)
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Source port should be rewritten to 22")
// Verify source IP remains unchanged
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
require.Equal(t, peerIP, srcIPAfter, "Source IP should remain unchanged")
// Test removal of SSH port redirection
err = manager.RemoveInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
}
// TestSSHPortRedirectionNetworkFiltering tests that SSH redirection only applies to specified networks
func TestSSHPortRedirectionNetworkFiltering(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerInNetwork := netip.MustParseAddr("100.10.0.50")
peerOutsideNetwork := netip.MustParseAddr("192.168.1.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule for NetBird network only
err = manager.AddInboundDNAT(peerInNetwork, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Test SSH packet to peer within NetBird network (should be redirected)
inNetworkPacket := generateDNATTestPacket(t, clientIP, peerInNetwork, layers.IPProtocolTCP, 54321, 22)
translated := manager.translateInboundPortDNAT(inNetworkPacket, parsePacket(t, inNetworkPacket))
require.True(t, translated, "SSH packet to NetBird peer should be translated")
// Verify port was changed
d := parsePacket(t, inNetworkPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected for NetBird peer")
// Test SSH packet to peer outside NetBird network (should NOT be redirected)
outOfNetworkPacket := generateDNATTestPacket(t, clientIP, peerOutsideNetwork, layers.IPProtocolTCP, 54321, 22)
originalOutOfNetwork := make([]byte, len(outOfNetworkPacket))
copy(originalOutOfNetwork, outOfNetworkPacket)
notTranslated := manager.translateInboundPortDNAT(outOfNetworkPacket, parsePacket(t, outOfNetworkPacket))
require.False(t, notTranslated, "SSH packet to non-NetBird peer should NOT be translated")
// Verify packet was not modified
require.Equal(t, originalOutOfNetwork, outOfNetworkPacket, "Packet to non-NetBird peer should remain unchanged")
}
// TestSSHPortRedirectionNonTCPTraffic tests that only TCP traffic is affected
func TestSSHPortRedirectionNonTCPTraffic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Test UDP packet on port 22 (should NOT be redirected)
udpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolUDP, 54321, 22)
originalUDP := make([]byte, len(udpPacket))
copy(originalUDP, udpPacket)
translated := manager.translateInboundPortDNAT(udpPacket, parsePacket(t, udpPacket))
require.False(t, translated, "UDP packet should NOT be translated by SSH port redirection")
require.Equal(t, originalUDP, udpPacket, "UDP packet should remain unchanged")
// Test ICMP packet (should NOT be redirected)
icmpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolICMPv4, 0, 0)
originalICMP := make([]byte, len(icmpPacket))
copy(originalICMP, icmpPacket)
translated = manager.translateInboundPortDNAT(icmpPacket, parsePacket(t, icmpPacket))
require.False(t, translated, "ICMP packet should NOT be translated by SSH port redirection")
require.Equal(t, originalICMP, icmpPacket, "ICMP packet should remain unchanged")
}
// TestSSHPortRedirectionNonSSHPorts tests that only port 22 is redirected
func TestSSHPortRedirectionNonSSHPorts(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Test TCP packet on port 80 (should NOT be redirected)
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
originalHTTP := make([]byte, len(httpPacket))
copy(originalHTTP, httpPacket)
translated := manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
require.Equal(t, originalHTTP, httpPacket, "Non-SSH TCP packet should remain unchanged")
// Test TCP packet on port 443 (should NOT be redirected)
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
originalHTTPS := make([]byte, len(httpsPacket))
copy(originalHTTPS, httpsPacket)
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
require.Equal(t, originalHTTPS, httpsPacket, "Non-SSH TCP packet should remain unchanged")
// Test TCP packet on port 22 (SHOULD be redirected)
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
translated = manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
require.True(t, translated, "SSH TCP packet should be translated")
// Verify port was changed to 22022
d := parsePacket(t, sshPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH port should be redirected to 22022")
}
// TestFlexiblePortRedirection tests the flexible port redirection functionality
func TestFlexiblePortRedirection(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer and client IPs
peerIP := netip.MustParseAddr("10.0.0.50")
clientIP := netip.MustParseAddr("10.0.0.100")
// Add custom port redirection: TCP port 8080 -> 3000 for peer IP
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
require.NoError(t, err)
// Verify port DNAT is enabled
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
// Verify the rule configuration
rule := manager.portDNATMap.rules[0]
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
require.Equal(t, uint16(8080), rule.sourcePort)
require.Equal(t, uint16(3000), rule.targetPort)
require.Equal(t, peerIP, rule.targetIP)
// Test inbound packet (client -> peer:8080, should redirect to peer:3000)
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 8080)
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, translated, "Inbound packet should be translated")
// Verify destination port was changed from 8080 to 3000
d := parsePacket(t, inboundPacket)
require.Equal(t, uint16(3000), uint16(d.tcp.DstPort), "Destination port should be rewritten to 3000")
// Test outbound return packet (peer:3000 -> client, should rewrite source port to 8080)
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 3000, 54321)
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, reversed, "Outbound return packet should be reverse translated")
// Verify source port was changed from 3000 to 8080
d = parsePacket(t, outboundPacket)
require.Equal(t, uint16(8080), uint16(d.tcp.SrcPort), "Source port should be rewritten to 8080")
// Test removal of port redirection
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
require.NoError(t, err)
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
}
// TestMultiplePortRedirections tests multiple port redirection rules
func TestMultiplePortRedirections(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer and client IPs
peerIP := netip.MustParseAddr("172.16.0.50")
clientIP := netip.MustParseAddr("172.16.0.100")
// Add multiple port redirections for peer IP
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 22, 22022) // SSH
require.NoError(t, err)
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) // HTTP
require.NoError(t, err)
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 443, 8443) // HTTPS
require.NoError(t, err)
// Verify all rules are present
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
require.Len(t, manager.portDNATMap.rules, 3, "Should have three port DNAT rules")
// Test SSH redirection (22 -> 22022)
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
require.True(t, translated, "SSH packet should be translated")
d := parsePacket(t, sshPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH should redirect to 22022")
// Test HTTP redirection (80 -> 8080)
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
translated = manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
require.True(t, translated, "HTTP packet should be translated")
d = parsePacket(t, httpPacket)
require.Equal(t, uint16(8080), uint16(d.tcp.DstPort), "HTTP should redirect to 8080")
// Test HTTPS redirection (443 -> 8443)
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
require.True(t, translated, "HTTPS packet should be translated")
d = parsePacket(t, httpsPacket)
require.Equal(t, uint16(8443), uint16(d.tcp.DstPort), "HTTPS should redirect to 8443")
// Test removing one rule (HTTP)
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080)
require.NoError(t, err)
require.Len(t, manager.portDNATMap.rules, 2, "Should have two rules after removing HTTP rule")
// Verify HTTP is no longer redirected
httpPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
originalHTTP := make([]byte, len(httpPacket2))
copy(originalHTTP, httpPacket2)
translated = manager.translateInboundPortDNAT(httpPacket2, parsePacket(t, httpPacket2))
require.False(t, translated, "HTTP packet should NOT be translated after rule removal")
require.Equal(t, originalHTTP, httpPacket2, "HTTP packet should remain unchanged")
// Verify SSH and HTTPS still work
sshPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
translated = manager.translateInboundPortDNAT(sshPacket2, parsePacket(t, sshPacket2))
require.True(t, translated, "SSH should still be translated")
httpsPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
translated = manager.translateInboundPortDNAT(httpsPacket2, parsePacket(t, httpsPacket2))
require.True(t, translated, "HTTPS should still be translated")
}
// TestSSHPortRedirectionEndToEnd tests actual network delivery through sockets
func TestSSHPortRedirectionEndToEnd(t *testing.T) {
// Start a mock SSH server on port 22022 (NetBird SSH server)
mockSSHServer, err := net.Listen("tcp", "127.0.0.1:22022")
require.NoError(t, err, "Should be able to bind to NetBird SSH port")
defer func() {
require.NoError(t, mockSSHServer.Close())
}()
// Handle connections on the SSH server
serverReceivedData := make(chan string, 1)
go func() {
for {
conn, err := mockSSHServer.Accept()
if err != nil {
return // Server closed
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
go func(conn net.Conn) {
defer func() {
require.NoError(t, conn.Close())
}()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
t.Logf("Server read error: %v", err)
return
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
receivedData := string(buf[:n])
serverReceivedData <- receivedData
// Echo back a response
_, err = conn.Write([]byte("SSH-2.0-MockNetBirdSSH\r\n"))
if err != nil {
t.Logf("Server write error: %v", err)
}
}(conn)
}
}()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// This test demonstrates what SHOULD happen after port redirection:
// 1. Client connects to 127.0.0.1:22 (standard SSH port)
// 2. Firewall redirects to 127.0.0.1:22022 (NetBird SSH server)
// 3. NetBird SSH server receives the connection
t.Run("DirectConnectionToNetBirdSSHPort", func(t *testing.T) {
// This simulates what should happen AFTER port redirection
// Connect directly to 22022 (where NetBird SSH server listens)
conn, err := net.DialTimeout("tcp", "127.0.0.1:22022", 5*time.Second)
require.NoError(t, err, "Should connect to NetBird SSH server")
defer func() {
require.NoError(t, conn.Close())
}()
// Send SSH client identification
testData := "SSH-2.0-TestClient\r\n"
_, err = conn.Write([]byte(testData))
require.NoError(t, err, "Should send data to SSH server")
// Verify server received the data
select {
case received := <-serverReceivedData:
require.Equal(t, testData, received, "Server should receive client data")
case <-time.After(2 * time.Second):
t.Fatal("Server did not receive data within timeout")
}
// Read server response
buf := make([]byte, 1024)
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Logf("failed to set read deadline: %v", err)
}
n, err := conn.Read(buf)
require.NoError(t, err, "Should read server response")
response := string(buf[:n])
require.Equal(t, "SSH-2.0-MockNetBirdSSH\r\n", response, "Should receive SSH server identification")
})
t.Run("PortRedirectionSimulation", func(t *testing.T) {
// This test simulates the port redirection process
// Note: This doesn't test the actual userspace packet interception,
// but demonstrates the expected behavior
t.Log("NOTE: This test demonstrates expected behavior after implementing")
t.Log("full userspace packet interception. Currently, we test packet")
t.Log("translation logic separately from actual network delivery.")
// In a real implementation with userspace packet interception:
// 1. Client would connect to 127.0.0.1:22
// 2. Userspace firewall would intercept packets
// 3. translateInboundPortDNAT would rewrite port 22 -> 22022
// 4. Packets would be delivered to 127.0.0.1:22022
// 5. NetBird SSH server would receive the connection
// For now, we verify that the packet translation logic works correctly
// (this is tested in other test functions) and that the target server
// is reachable (tested above)
clientIP := netip.MustParseAddr("127.0.0.1")
serverIP := netip.MustParseAddr("127.0.0.1")
// Create manager with SSH port redirection
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Add SSH port redirection for localhost (for testing)
err = manager.AddInboundDNAT(netip.MustParseAddr("127.0.0.1"), firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Generate packet: client connecting to server:22
sshPacket := generateDNATTestPacket(t, clientIP, serverIP, layers.IPProtocolTCP, 54321, 22)
originalPacket := make([]byte, len(sshPacket))
copy(originalPacket, sshPacket)
// Apply port redirection
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
require.True(t, translated, "SSH packet should be translated")
// Verify port was redirected from 22 to 22022
d := parsePacket(t, sshPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected to NetBird SSH server")
require.NotEqual(t, originalPacket, sshPacket, "Packet should be modified")
t.Log("✓ Packet translation verified: port 22 redirected to 22022")
t.Log("✓ Target SSH server (port 22022) is reachable and responsive")
t.Log("→ Integration complete: SSH port redirection ready for userspace interception")
})
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
// TestFullSSHRedirectionWorkflow demonstrates the complete SSH redirection workflow
func TestFullSSHRedirectionWorkflow(t *testing.T) {
t.Log("=== SSH Port Redirection Workflow Test ===")
t.Log("This test demonstrates the complete SSH redirection process:")
t.Log("1. Client connects to peer:22 (standard SSH)")
t.Log("2. Userspace firewall intercepts and redirects to peer:22022")
t.Log("3. NetBird SSH server receives connection on port 22022")
t.Log("4. Return traffic is reverse-translated (22022 -> 22)")
// Setup test environment
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
@@ -622,47 +204,51 @@ func TestFullSSHRedirectionWorkflow(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network and peer IPs
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
// Step 1: Configure SSH port redirection
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
t.Log("✓ SSH port redirection configured for NetBird network")
// Step 2: Simulate inbound SSH connection (client -> peer:22)
t.Log("→ Simulating: ssh user@100.10.0.50")
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
// Step 3: Apply inbound port redirection
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, translated, "Inbound SSH packet should be redirected")
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
d := parsePacket(t, inboundPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Should redirect to NetBird SSH server port")
t.Log("✓ Inbound packet redirected: 100.10.0.50:22 → 100.10.0.50:22022")
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
// Step 4: Simulate outbound return traffic (peer:22022 -> client)
t.Log("→ Simulating return traffic from NetBird SSH server")
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
// Step 5: Apply outbound reverse translation
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, reversed, "Outbound return packet should be reverse translated")
d = parsePacket(t, outboundPacket)
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Should restore original SSH port")
t.Log("✓ Outbound packet reverse translated: 100.10.0.50:22022 → 100.10.0.50:22")
// Step 6: Verify client sees standard SSH connection
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
require.Equal(t, peerIP, srcIPAfter, "Client should see traffic from peer IP")
t.Log("✓ Client receives traffic from 100.10.0.50:22 (transparent redirection)")
t.Log("=== SSH Port Redirection Workflow Complete ===")
t.Log("Result: Standard SSH clients can connect to NetBird peers using:")
t.Log(" ssh user@100.10.0.50")
t.Log("Instead of:")
t.Log(" ssh user@100.10.0.50 -p 22022")
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,25 +16,33 @@ type PacketStage int
const (
StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageReceived: "Received",
StageInboundPortDNAT: "Inbound Port DNAT",
StageInbound1to1NAT: "Inbound 1:1 NAT",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s]
}
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
}
return trace
}
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageRouteACL,
@@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StageCompleted,
@@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageCompleted,
},
@@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: true,
@@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
@@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) {
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting,
StagePeerACL,
StageCompleted,