Complete overhaul

This commit is contained in:
Viktor Liu
2025-06-24 12:19:53 +02:00
parent f56075ca15
commit 9d1554f9f7
74 changed files with 16626 additions and 4524 deletions

View File

@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes inbound DNAT rule
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes inbound DNAT rule
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -151,14 +151,20 @@ type Manager interface {
DisableRouting() error
// AddDNATRule adds a DNAT rule
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
// DeleteDNATRule deletes the outbound DNAT rule.
DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes)
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes inbound DNAT rule
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes inbound DNAT rule
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -29,6 +29,12 @@ import (
const layerTypeAll = 0
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
@@ -110,6 +116,15 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
// Port-specific DNAT for SSH redirection
portDNATEnabled atomic.Bool
portDNATMap *portDNATMap
portDNATMutex sync.RWMutex
portNATTracker *portNATTracker
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
}
// decoder for packages
@@ -196,6 +211,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATMap: &portDNATMap{rules: make([]portDNATRule, 0)},
portNATTracker: newPortNATTracker(),
netstackServices: make(map[serviceKey]struct{}),
}
m.routingEnabled.Store(false)
@@ -333,18 +351,22 @@ func (m *Manager) initForwarder() error {
return nil
}
// Init initializes the firewall manager with state manager.
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
// IsServerRouteSupported returns whether server routes are supported.
func (m *Manager) IsServerRouteSupported() bool {
return true
}
// IsStateful returns whether the firewall manager tracks connection state.
func (m *Manager) IsStateful() bool {
return m.stateful
}
// AddNatRule adds a routing firewall rule for NAT translation.
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
@@ -611,6 +633,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
m.translateOutboundPortReverse(packetData, d)
return false
}
@@ -738,6 +761,15 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false
}
if translated := m.translateInboundPortDNAT(packetData, d); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
@@ -786,9 +818,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true
}
// If requested we pass local traffic to internal interfaces to the forwarder.
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
if m.shouldForward(d, dstIP) {
return m.handleForwardedLocalTraffic(packetData)
}
@@ -1215,3 +1245,95 @@ func (m *Manager) DisableRouting() error {
return nil
}
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
m.netstackServices[key] = struct{}{}
m.logger.Debug("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
m.logger.Debug("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
}
// UnregisterNetstackService removes a service from the netstack registry
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
delete(m.netstackServices, key)
m.logger.Debug("Unregistered netstack service on protocol %s port %d", protocol, port)
}
// isNetstackService checks if a service is registered as listening on netstack for the given protocol and port
func (m *Manager) isNetstackService(layerType gopacket.LayerType, port uint16) bool {
m.netstackServiceMutex.RLock()
defer m.netstackServiceMutex.RUnlock()
key := serviceKey{protocol: layerType, port: port}
_, exists := m.netstackServices[key]
return exists
}
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
switch protocol {
case nftypes.TCP:
return layers.LayerTypeTCP
case nftypes.UDP:
return layers.LayerTypeUDP
case nftypes.ICMP:
return layers.LayerTypeICMPv4
default:
return gopacket.LayerType(0) // Invalid/unknown
}
}
// shouldForward determines if a packet should be forwarded to the forwarder.
// The forwarder handles routing packets to the native OS network stack.
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
// not enabled, never forward
if !m.localForwarding {
return false
}
// netstack always needs to forward because it's lacking a native interface
// exception for registered netstack services, those should go to netstack listeners
if m.netstack {
return !m.hasMatchingNetstackService(d)
}
// traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP {
return true
}
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
return false
}
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
if len(d.decoded) < 2 {
return false
}
var dstPort uint16
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
dstPort = uint16(d.udp.DstPort)
default:
return false
}
key := serviceKey{protocol: d.decoded[1], port: dstPort}
m.netstackServiceMutex.RLock()
_, exists := m.netstackServices[key]
m.netstackServiceMutex.RUnlock()
return exists
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/management/domain"
)
@@ -896,3 +897,138 @@ func TestUpdateSetDeduplication(t *testing.T) {
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}
func TestShouldForward(t *testing.T) {
// Set up test addresses
wgIP := netip.MustParseAddr("100.10.0.1")
otherIP := netip.MustParseAddr("100.10.0.2")
// Create test manager with mock interface
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// Set the mock to return our test WG IP
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Helper to create decoder with TCP packet
createTCPDecoder := func(dstPort uint16) *decoder {
ipv4 := &layers.IPv4{
Version: 4,
Protocol: layers.IPProtocolTCP,
SrcIP: net.ParseIP("192.168.1.100"),
DstIP: wgIP.AsSlice(),
}
tcp := &layers.TCP{
SrcPort: 54321,
DstPort: layers.TCPPort(dstPort),
}
err := tcp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
require.NoError(t, err)
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
require.NoError(t, err)
return d
}
tests := []struct {
name string
localForwarding bool
netstack bool
dstIP netip.Addr
serviceRegistered bool
servicePort uint16
expected bool
description string
}{
{
name: "no local forwarding",
localForwarding: false,
netstack: true,
dstIP: wgIP,
expected: false,
description: "should never forward when local forwarding disabled",
},
{
name: "traffic to other local interface",
localForwarding: true,
netstack: false,
dstIP: otherIP,
expected: true,
description: "should forward traffic to our other local interfaces (not NetBird IP)",
},
{
name: "traffic to NetBird IP, no netstack",
localForwarding: true,
netstack: false,
dstIP: wgIP,
expected: false,
description: "should send to netstack listeners (final return false path)",
},
{
name: "traffic to our IP, netstack mode, no service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
expected: true,
description: "should forward when in netstack mode with no matching service",
},
{
name: "traffic to our IP, netstack mode, with service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
serviceRegistered: true,
servicePort: 22,
expected: false,
description: "should send to netstack listeners when service is registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Configure manager
manager.localForwarding = tt.localForwarding
manager.netstack = tt.netstack
// Register service if needed
if tt.serviceRegistered {
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
}
// Create decoder for the test
decoder := createTCPDecoder(tt.servicePort)
if !tt.serviceRegistered {
decoder = createTCPDecoder(8080) // Use non-registered port
}
// Test the method
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
}

View File

@@ -5,7 +5,10 @@ import (
"errors"
"fmt"
"net/netip"
"sync"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -13,6 +16,12 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
const (
invalidIPHeaderLengthMsg = "invalid IP header length"
errRewriteTCPDestinationPort = "rewrite TCP destination port: %v"
)
// ipv4Checksum calculates IPv4 header checksum using optimized parallel processing for performance.
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
@@ -20,13 +29,11 @@ func ipv4Checksum(header []byte) uint16 {
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
@@ -34,7 +41,6 @@ func ipv4Checksum(header []byte) uint16 {
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
@@ -43,7 +49,6 @@ func ipv4Checksum(header []byte) uint16 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
@@ -52,11 +57,11 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum)
}
// icmpChecksum calculates ICMP checksum using parallel accumulation for high-performance processing.
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
@@ -71,7 +76,6 @@ func icmpChecksum(data []byte) uint16 {
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
@@ -89,11 +93,131 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum)
}
// biDNATMap maintains bidirectional DNAT mappings for efficient forward and reverse lookups.
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
// portDNATRule represents a port-specific DNAT rule
type portDNATRule struct {
protocol gopacket.LayerType
sourcePort uint16
targetPort uint16
targetIP netip.Addr
}
// portDNATMap manages port-specific DNAT rules
type portDNATMap struct {
rules []portDNATRule
}
// ConnKey represents a connection 4-tuple for NAT tracking.
type ConnKey struct {
SrcIP netip.Addr
DstIP netip.Addr
SrcPort uint16
DstPort uint16
}
// portNATConn tracks port NAT state for a specific connection.
type portNATConn struct {
rule portDNATRule
originalPort uint16
translatedAt time.Time
}
// portNATTracker tracks connection-specific port NAT state
type portNATTracker struct {
connections map[ConnKey]*portNATConn
mutex sync.RWMutex
}
// newPortNATTracker creates a new port NAT tracker for stateful connection tracking.
func newPortNATTracker() *portNATTracker {
return &portNATTracker{
connections: make(map[ConnKey]*portNATConn),
}
}
// trackConnection tracks a connection that has port NAT applied using translated port as key.
func (t *portNATTracker) trackConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, rule portDNATRule) {
t.mutex.Lock()
defer t.mutex.Unlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: rule.targetPort,
}
t.connections[key] = &portNATConn{
rule: rule,
originalPort: dstPort,
translatedAt: time.Now(),
}
}
// getConnectionNAT returns NAT info for a connection if tracked, looking up by connection 4-tuple.
func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) (*portNATConn, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// removeConnection removes a tracked connection from the NAT tracking table.
func (t *portNATTracker) removeConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
delete(t.connections, key)
}
// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts.
func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool {
t.mutex.RLock()
defer t.mutex.RUnlock()
for key, conn := range t.connections {
if key.SrcIP == dstIP && key.DstIP == srcIP &&
conn.rule.sourcePort == dstPort && conn.originalPort == dstPort {
return false
}
}
return true
}
// cleanupConnection removes a NAT connection based on original 4-tuple for connection cleanup.
func (t *portNATTracker) cleanupConnection(srcIP, dstIP netip.Addr, srcPort uint16) {
t.mutex.Lock()
defer t.mutex.Unlock()
for key := range t.connections {
if key.SrcIP == srcIP && key.DstIP == dstIP && key.SrcPort == srcPort {
delete(t.connections, key)
return
}
}
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure for efficient forward/reverse lookups.
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
@@ -101,11 +225,13 @@ func newBiDNATMap() *biDNATMap {
}
}
// set adds a bidirectional DNAT mapping between original and translated addresses for both directions.
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
@@ -113,19 +239,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
}
}
// getTranslated returns the translated address for a given original address from forward mapping.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
// getOriginal returns the original address for a given translated address from reverse mapping.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
if !originalAddr.IsValid() {
return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -135,7 +267,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
@@ -151,7 +282,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
@@ -169,7 +300,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
// getDNATTranslation returns the translated address if a mapping exists with fast-path optimization.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
@@ -181,7 +312,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
// findReverseDNATMapping finds original address for return traffic using reverse mapping.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
@@ -193,7 +324,7 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
// translateOutboundDNAT applies DNAT translation to outbound packets for 1:1 IP mapping.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
@@ -211,7 +342,7 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
m.logger.Error("rewrite packet destination: %v", err)
return false
}
@@ -219,7 +350,7 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
// translateInboundReverse applies reverse DNAT to inbound return traffic for 1:1 IP mapping.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
@@ -237,7 +368,7 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
m.logger.Error("rewrite packet source: %v", err)
return false
}
@@ -245,7 +376,7 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true
}
// rewritePacketDestination replaces destination IP in the packet
// rewritePacketDestination replaces destination IP in the packet and updates checksums.
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
@@ -259,7 +390,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
return fmt.Errorf(invalidIPHeaderLengthMsg)
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -280,7 +411,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
return nil
}
// rewritePacketSource replaces the source IP address in the packet
// rewritePacketSource replaces the source IP address in the packet and updates checksums.
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
@@ -294,7 +425,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
return fmt.Errorf(invalidIPHeaderLengthMsg)
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -315,6 +446,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
return nil
}
// updateTCPChecksum updates TCP checksum after IP address change using incremental update per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
@@ -327,6 +459,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateUDPChecksum updates UDP checksum after IP address change using incremental update per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
@@ -344,6 +477,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
// updateICMPChecksum recalculates ICMP checksum after packet modification using full recalculation.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
@@ -356,18 +490,16 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
// incrementalUpdate performs incremental checksum update per RFC 1624 for performance.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
@@ -391,7 +523,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
@@ -399,10 +531,318 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
// DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// addPortRedirection adds port redirection rule for specified target IP, protocol and ports.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
sourcePort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATMap.rules = append(m.portDNATMap.rules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
if protocol == firewall.ProtocolTCP {
layerType = layers.LayerTypeTCP
} else if protocol == firewall.ProtocolUDP {
layerType = layers.LayerTypeUDP
} else {
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes port redirection rule for specified target IP, protocol and ports.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
var filteredRules []portDNATRule
for _, rule := range m.portDNATMap.rules {
if !(rule.protocol == protocol && rule.sourcePort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0) {
filteredRules = append(filteredRules, rule)
}
}
m.portDNATMap.rules = filteredRules
if len(m.portDNATMap.rules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
if protocol == firewall.ProtocolTCP {
layerType = layers.LayerTypeTCP
} else if protocol == firewall.ProtocolUDP {
layerType = layers.LayerTypeUDP
} else {
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies stateful port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
if m.handleReturnTraffic(packetData, d, srcIP, dstIP, srcPort, dstPort) {
return true
}
return m.handleNewConnection(packetData, d, srcIP, dstIP, srcPort, dstPort)
}
// handleReturnTraffic processes return traffic for existing NAT connections.
func (m *Manager) handleReturnTraffic(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
if m.isTranslatedPortTraffic(srcIP, srcPort) {
return false
}
if handled := m.handleExistingNATConnection(packetData, d, srcIP, dstIP, srcPort, dstPort); handled {
return true
}
return m.handleForwardTrafficInExistingConnections(packetData, d, srcIP, dstIP, srcPort, dstPort)
}
// isTranslatedPortTraffic checks if traffic is from a translated port that should be handled by outbound reverse.
func (m *Manager) isTranslatedPortTraffic(srcIP netip.Addr, srcPort uint16) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATMap.rules {
if rule.protocol == layers.LayerTypeTCP && rule.targetPort == srcPort &&
rule.targetIP.Unmap().Compare(srcIP.Unmap()) == 0 {
return true
}
}
return false
}
// handleExistingNATConnection processes return traffic for existing NAT connections.
func (m *Manager) handleExistingNATConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
if err := m.rewriteTCPDestinationPort(packetData, d, natConn.originalPort); err != nil {
m.logger.Error(errRewriteTCPDestinationPort, err)
return false
}
m.logger.Trace("Inbound Port DNAT (return): %s:%d -> %s:%d", dstIP, srcPort, dstIP, natConn.originalPort)
return true
}
return false
}
// handleForwardTrafficInExistingConnections processes forward traffic in existing connections.
func (m *Manager) handleForwardTrafficInExistingConnections(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATMap.rules {
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
continue
}
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
continue
}
if _, exists := m.portNATTracker.getConnectionNAT(srcIP, dstIP, srcPort, rule.targetPort); !exists {
continue
}
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
m.logger.Error(errRewriteTCPDestinationPort, err)
return false
}
return true
}
return false
}
// handleNewConnection processes new connections that match port DNAT rules.
func (m *Manager) handleNewConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATMap.rules {
if m.applyPortDNATRule(packetData, d, rule, srcIP, dstIP, srcPort, dstPort) {
return true
}
}
return false
}
// applyPortDNATRule applies a specific port DNAT rule if conditions are met.
func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNATRule, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
return false
}
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
return false
}
if !m.portNATTracker.shouldApplyNAT(srcIP, dstIP, dstPort) {
return false
}
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
m.logger.Error(errRewriteTCPDestinationPort, err)
return false
}
m.portNATTracker.trackConnection(srcIP, dstIP, srcPort, dstPort, rule)
m.logger.Trace("Inbound Port DNAT (new): %s:%d -> %s:%d (tracked: %s:%d -> %s:%d)", dstIP, rule.sourcePort, dstIP, rule.targetPort, srcIP, srcPort, dstIP, rule.targetPort)
return true
}
// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum.
func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return ErrIPv4Only
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return fmt.Errorf("not a TCP packet")
}
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf(invalidIPHeaderLengthMsg)
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4])
binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum.
func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return ErrIPv4Only
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return fmt.Errorf("not a TCP packet")
}
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf(invalidIPHeaderLengthMsg)
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2])
binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// translateOutboundPortReverse applies stateful reverse port DNAT to outbound return traffic for SSH redirection.
func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
// For outbound reverse, we need to find the connection using the same key as when it was stored
// Connection was stored as: srcIP, dstIP, srcPort, translatedPort
// So for return traffic (srcIP=server, dstIP=client), we need: dstIP, srcIP, dstPort, srcPort
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
if err := m.rewriteTCPSourcePort(packetData, d, natConn.rule.sourcePort); err != nil {
m.logger.Error("rewrite TCP source port: %v", err)
return false
}
return true
}
return false
}

View File

@@ -0,0 +1,111 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestStatefulNATBidirectionalSSH tests that stateful NAT prevents interference
// when two peers try to SSH to each other simultaneously
func TestStatefulNATBidirectionalSSH(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rule for peer B (the target)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
// This simulates: ssh user@100.10.0.51
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, parsePacket(t, packetAtoB))
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
// Verify port was translated to 22022
d := parsePacket(t, packetAtoB)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
// Verify NAT connection is tracked (with translated port as key)
natConn, exists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022)
require.True(t, exists, "NAT connection should be tracked")
require.Equal(t, uint16(22), natConn.originalPort, "Original port should be stored")
// Scenario: Peer B tries to connect to Peer A on port 22 (should NOT get NAT)
// This simulates the reverse direction to prevent interference
packetBtoA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
translatedBtoA := manager.translateInboundPortDNAT(packetBtoA, parsePacket(t, packetBtoA))
require.False(t, translatedBtoA, "Peer B to Peer A should NOT be translated (prevent interference)")
// Verify port was NOT translated
d2 := parsePacket(t, packetBtoA)
require.Equal(t, uint16(22), uint16(d2.tcp.DstPort), "Port should remain 22 (no translation)")
// Verify no reverse NAT connection is tracked
_, reverseExists := manager.portNATTracker.getConnectionNAT(peerB, peerA, 54322, 22)
require.False(t, reverseExists, "Reverse NAT connection should NOT be tracked")
// Scenario: Return traffic from Peer B (SSH server) to Peer A (should be reverse translated)
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
translatedReturn := manager.translateOutboundPortReverse(returnPacket, parsePacket(t, returnPacket))
require.True(t, translatedReturn, "Return traffic should be reverse translated")
// Verify return traffic port was translated back to 22
d3 := parsePacket(t, returnPacket)
require.Equal(t, uint16(22), uint16(d3.tcp.SrcPort), "Return traffic source port should be 22")
}
// TestStatefulNATConnectionCleanup tests connection cleanup functionality
func TestStatefulNATConnectionCleanup(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rules for both peers
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Establish connection with NAT
packet := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
translated := manager.translateInboundPortDNAT(packet, parsePacket(t, packet))
require.True(t, translated, "Initial connection should be translated")
// Verify connection is tracked (using translated port as key)
_, exists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022)
require.True(t, exists, "Connection should be tracked")
// Clean up connection
manager.portNATTracker.cleanupConnection(peerA, peerB, 54321)
// Verify connection is no longer tracked (using translated port as key)
_, stillExists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022)
require.False(t, stillExists, "Connection should be cleaned up")
// Verify new connection from opposite direction now works
reversePacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
reverseTranslated := manager.translateInboundPortDNAT(reversePacket, parsePacket(t, reversePacket))
require.True(t, reverseTranslated, "Reverse connection should now work after cleanup")
}

View File

@@ -1,13 +1,17 @@
package uspfilter
import (
"io"
"net"
"net/netip"
"testing"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -143,3 +147,520 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}
// TestSSHPortRedirection tests SSH port redirection functionality
func TestSSHPortRedirection(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Verify port DNAT is enabled
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
// Verify the rule configuration
rule := manager.portDNATMap.rules[0]
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
require.Equal(t, uint16(22), rule.sourcePort)
require.Equal(t, uint16(22022), rule.targetPort)
require.Equal(t, peerIP, rule.targetIP)
// Test inbound SSH packet (client -> peer:22, should redirect to peer:22022)
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, translated, "Inbound SSH packet should be translated")
// Verify destination port was changed from 22 to 22022
d := parsePacket(t, inboundPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Destination port should be rewritten to 22022")
// Verify destination IP remains unchanged
dstIPAfter := netip.AddrFrom4([4]byte{inboundPacket[16], inboundPacket[17], inboundPacket[18], inboundPacket[19]})
require.Equal(t, peerIP, dstIPAfter, "Destination IP should remain unchanged")
// Test outbound return packet (peer:22022 -> client, should rewrite source port to 22)
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound return packet
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, reversed, "Outbound return packet should be reverse translated")
// Verify source port was changed from 22022 to 22
d = parsePacket(t, outboundPacket)
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Source port should be rewritten to 22")
// Verify source IP remains unchanged
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
require.Equal(t, peerIP, srcIPAfter, "Source IP should remain unchanged")
// Test removal of SSH port redirection
err = manager.RemoveInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
}
// TestSSHPortRedirectionNetworkFiltering tests that SSH redirection only applies to specified networks
func TestSSHPortRedirectionNetworkFiltering(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerInNetwork := netip.MustParseAddr("100.10.0.50")
peerOutsideNetwork := netip.MustParseAddr("192.168.1.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule for NetBird network only
err = manager.AddInboundDNAT(peerInNetwork, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Test SSH packet to peer within NetBird network (should be redirected)
inNetworkPacket := generateDNATTestPacket(t, clientIP, peerInNetwork, layers.IPProtocolTCP, 54321, 22)
translated := manager.translateInboundPortDNAT(inNetworkPacket, parsePacket(t, inNetworkPacket))
require.True(t, translated, "SSH packet to NetBird peer should be translated")
// Verify port was changed
d := parsePacket(t, inNetworkPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected for NetBird peer")
// Test SSH packet to peer outside NetBird network (should NOT be redirected)
outOfNetworkPacket := generateDNATTestPacket(t, clientIP, peerOutsideNetwork, layers.IPProtocolTCP, 54321, 22)
originalOutOfNetwork := make([]byte, len(outOfNetworkPacket))
copy(originalOutOfNetwork, outOfNetworkPacket)
notTranslated := manager.translateInboundPortDNAT(outOfNetworkPacket, parsePacket(t, outOfNetworkPacket))
require.False(t, notTranslated, "SSH packet to non-NetBird peer should NOT be translated")
// Verify packet was not modified
require.Equal(t, originalOutOfNetwork, outOfNetworkPacket, "Packet to non-NetBird peer should remain unchanged")
}
// TestSSHPortRedirectionNonTCPTraffic tests that only TCP traffic is affected
func TestSSHPortRedirectionNonTCPTraffic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Test UDP packet on port 22 (should NOT be redirected)
udpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolUDP, 54321, 22)
originalUDP := make([]byte, len(udpPacket))
copy(originalUDP, udpPacket)
translated := manager.translateInboundPortDNAT(udpPacket, parsePacket(t, udpPacket))
require.False(t, translated, "UDP packet should NOT be translated by SSH port redirection")
require.Equal(t, originalUDP, udpPacket, "UDP packet should remain unchanged")
// Test ICMP packet (should NOT be redirected)
icmpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolICMPv4, 0, 0)
originalICMP := make([]byte, len(icmpPacket))
copy(originalICMP, icmpPacket)
translated = manager.translateInboundPortDNAT(icmpPacket, parsePacket(t, icmpPacket))
require.False(t, translated, "ICMP packet should NOT be translated by SSH port redirection")
require.Equal(t, originalICMP, icmpPacket, "ICMP packet should remain unchanged")
}
// TestSSHPortRedirectionNonSSHPorts tests that only port 22 is redirected
func TestSSHPortRedirectionNonSSHPorts(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network range
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Add SSH port redirection rule
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Test TCP packet on port 80 (should NOT be redirected)
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
originalHTTP := make([]byte, len(httpPacket))
copy(originalHTTP, httpPacket)
translated := manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
require.Equal(t, originalHTTP, httpPacket, "Non-SSH TCP packet should remain unchanged")
// Test TCP packet on port 443 (should NOT be redirected)
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
originalHTTPS := make([]byte, len(httpsPacket))
copy(originalHTTPS, httpsPacket)
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
require.Equal(t, originalHTTPS, httpsPacket, "Non-SSH TCP packet should remain unchanged")
// Test TCP packet on port 22 (SHOULD be redirected)
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
translated = manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
require.True(t, translated, "SSH TCP packet should be translated")
// Verify port was changed to 22022
d := parsePacket(t, sshPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH port should be redirected to 22022")
}
// TestFlexiblePortRedirection tests the flexible port redirection functionality
func TestFlexiblePortRedirection(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer and client IPs
peerIP := netip.MustParseAddr("10.0.0.50")
clientIP := netip.MustParseAddr("10.0.0.100")
// Add custom port redirection: TCP port 8080 -> 3000 for peer IP
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
require.NoError(t, err)
// Verify port DNAT is enabled
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
// Verify the rule configuration
rule := manager.portDNATMap.rules[0]
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
require.Equal(t, uint16(8080), rule.sourcePort)
require.Equal(t, uint16(3000), rule.targetPort)
require.Equal(t, peerIP, rule.targetIP)
// Test inbound packet (client -> peer:8080, should redirect to peer:3000)
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 8080)
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, translated, "Inbound packet should be translated")
// Verify destination port was changed from 8080 to 3000
d := parsePacket(t, inboundPacket)
require.Equal(t, uint16(3000), uint16(d.tcp.DstPort), "Destination port should be rewritten to 3000")
// Test outbound return packet (peer:3000 -> client, should rewrite source port to 8080)
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 3000, 54321)
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, reversed, "Outbound return packet should be reverse translated")
// Verify source port was changed from 3000 to 8080
d = parsePacket(t, outboundPacket)
require.Equal(t, uint16(8080), uint16(d.tcp.SrcPort), "Source port should be rewritten to 8080")
// Test removal of port redirection
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
require.NoError(t, err)
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
}
// TestMultiplePortRedirections tests multiple port redirection rules
func TestMultiplePortRedirections(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer and client IPs
peerIP := netip.MustParseAddr("172.16.0.50")
clientIP := netip.MustParseAddr("172.16.0.100")
// Add multiple port redirections for peer IP
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 22, 22022) // SSH
require.NoError(t, err)
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) // HTTP
require.NoError(t, err)
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 443, 8443) // HTTPS
require.NoError(t, err)
// Verify all rules are present
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
require.Len(t, manager.portDNATMap.rules, 3, "Should have three port DNAT rules")
// Test SSH redirection (22 -> 22022)
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
require.True(t, translated, "SSH packet should be translated")
d := parsePacket(t, sshPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH should redirect to 22022")
// Test HTTP redirection (80 -> 8080)
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
translated = manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
require.True(t, translated, "HTTP packet should be translated")
d = parsePacket(t, httpPacket)
require.Equal(t, uint16(8080), uint16(d.tcp.DstPort), "HTTP should redirect to 8080")
// Test HTTPS redirection (443 -> 8443)
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
require.True(t, translated, "HTTPS packet should be translated")
d = parsePacket(t, httpsPacket)
require.Equal(t, uint16(8443), uint16(d.tcp.DstPort), "HTTPS should redirect to 8443")
// Test removing one rule (HTTP)
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080)
require.NoError(t, err)
require.Len(t, manager.portDNATMap.rules, 2, "Should have two rules after removing HTTP rule")
// Verify HTTP is no longer redirected
httpPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
originalHTTP := make([]byte, len(httpPacket2))
copy(originalHTTP, httpPacket2)
translated = manager.translateInboundPortDNAT(httpPacket2, parsePacket(t, httpPacket2))
require.False(t, translated, "HTTP packet should NOT be translated after rule removal")
require.Equal(t, originalHTTP, httpPacket2, "HTTP packet should remain unchanged")
// Verify SSH and HTTPS still work
sshPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
translated = manager.translateInboundPortDNAT(sshPacket2, parsePacket(t, sshPacket2))
require.True(t, translated, "SSH should still be translated")
httpsPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
translated = manager.translateInboundPortDNAT(httpsPacket2, parsePacket(t, httpsPacket2))
require.True(t, translated, "HTTPS should still be translated")
}
// TestSSHPortRedirectionEndToEnd tests actual network delivery through sockets
func TestSSHPortRedirectionEndToEnd(t *testing.T) {
// Start a mock SSH server on port 22022 (NetBird SSH server)
mockSSHServer, err := net.Listen("tcp", "127.0.0.1:22022")
require.NoError(t, err, "Should be able to bind to NetBird SSH port")
defer func() {
require.NoError(t, mockSSHServer.Close())
}()
// Handle connections on the SSH server
serverReceivedData := make(chan string, 1)
go func() {
for {
conn, err := mockSSHServer.Accept()
if err != nil {
return // Server closed
}
go func(conn net.Conn) {
defer func() {
require.NoError(t, conn.Close())
}()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil && err != io.EOF {
t.Logf("Server read error: %v", err)
return
}
receivedData := string(buf[:n])
serverReceivedData <- receivedData
// Echo back a response
_, err = conn.Write([]byte("SSH-2.0-MockNetBirdSSH\r\n"))
if err != nil {
t.Logf("Server write error: %v", err)
}
}(conn)
}
}()
// Give server time to start
time.Sleep(100 * time.Millisecond)
// This test demonstrates what SHOULD happen after port redirection:
// 1. Client connects to 127.0.0.1:22 (standard SSH port)
// 2. Firewall redirects to 127.0.0.1:22022 (NetBird SSH server)
// 3. NetBird SSH server receives the connection
t.Run("DirectConnectionToNetBirdSSHPort", func(t *testing.T) {
// This simulates what should happen AFTER port redirection
// Connect directly to 22022 (where NetBird SSH server listens)
conn, err := net.DialTimeout("tcp", "127.0.0.1:22022", 5*time.Second)
require.NoError(t, err, "Should connect to NetBird SSH server")
defer func() {
require.NoError(t, conn.Close())
}()
// Send SSH client identification
testData := "SSH-2.0-TestClient\r\n"
_, err = conn.Write([]byte(testData))
require.NoError(t, err, "Should send data to SSH server")
// Verify server received the data
select {
case received := <-serverReceivedData:
require.Equal(t, testData, received, "Server should receive client data")
case <-time.After(2 * time.Second):
t.Fatal("Server did not receive data within timeout")
}
// Read server response
buf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := conn.Read(buf)
require.NoError(t, err, "Should read server response")
response := string(buf[:n])
require.Equal(t, "SSH-2.0-MockNetBirdSSH\r\n", response, "Should receive SSH server identification")
})
t.Run("PortRedirectionSimulation", func(t *testing.T) {
// This test simulates the port redirection process
// Note: This doesn't test the actual userspace packet interception,
// but demonstrates the expected behavior
t.Log("NOTE: This test demonstrates expected behavior after implementing")
t.Log("full userspace packet interception. Currently, we test packet")
t.Log("translation logic separately from actual network delivery.")
// In a real implementation with userspace packet interception:
// 1. Client would connect to 127.0.0.1:22
// 2. Userspace firewall would intercept packets
// 3. translateInboundPortDNAT would rewrite port 22 -> 22022
// 4. Packets would be delivered to 127.0.0.1:22022
// 5. NetBird SSH server would receive the connection
// For now, we verify that the packet translation logic works correctly
// (this is tested in other test functions) and that the target server
// is reachable (tested above)
clientIP := netip.MustParseAddr("127.0.0.1")
serverIP := netip.MustParseAddr("127.0.0.1")
// Create manager with SSH port redirection
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Add SSH port redirection for localhost (for testing)
err = manager.AddInboundDNAT(netip.MustParseAddr("127.0.0.1"), firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
// Generate packet: client connecting to server:22
sshPacket := generateDNATTestPacket(t, clientIP, serverIP, layers.IPProtocolTCP, 54321, 22)
originalPacket := make([]byte, len(sshPacket))
copy(originalPacket, sshPacket)
// Apply port redirection
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
require.True(t, translated, "SSH packet should be translated")
// Verify port was redirected from 22 to 22022
d := parsePacket(t, sshPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected to NetBird SSH server")
require.NotEqual(t, originalPacket, sshPacket, "Packet should be modified")
t.Log("✓ Packet translation verified: port 22 redirected to 22022")
t.Log("✓ Target SSH server (port 22022) is reachable and responsive")
t.Log("→ Integration complete: SSH port redirection ready for userspace interception")
})
}
// TestFullSSHRedirectionWorkflow demonstrates the complete SSH redirection workflow
func TestFullSSHRedirectionWorkflow(t *testing.T) {
t.Log("=== SSH Port Redirection Workflow Test ===")
t.Log("This test demonstrates the complete SSH redirection process:")
t.Log("1. Client connects to peer:22 (standard SSH)")
t.Log("2. Userspace firewall intercepts and redirects to peer:22022")
t.Log("3. NetBird SSH server receives connection on port 22022")
t.Log("4. Return traffic is reverse-translated (22022 -> 22)")
// Setup test environment
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define NetBird network and peer IPs
peerIP := netip.MustParseAddr("100.10.0.50")
clientIP := netip.MustParseAddr("100.10.0.100")
// Step 1: Configure SSH port redirection
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
t.Log("✓ SSH port redirection configured for NetBird network")
// Step 2: Simulate inbound SSH connection (client -> peer:22)
t.Log("→ Simulating: ssh user@100.10.0.50")
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
// Step 3: Apply inbound port redirection
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, translated, "Inbound SSH packet should be redirected")
d := parsePacket(t, inboundPacket)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Should redirect to NetBird SSH server port")
t.Log("✓ Inbound packet redirected: 100.10.0.50:22 → 100.10.0.50:22022")
// Step 4: Simulate outbound return traffic (peer:22022 -> client)
t.Log("→ Simulating return traffic from NetBird SSH server")
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
// Step 5: Apply outbound reverse translation
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, reversed, "Outbound return packet should be reverse translated")
d = parsePacket(t, outboundPacket)
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Should restore original SSH port")
t.Log("✓ Outbound packet reverse translated: 100.10.0.50:22022 → 100.10.0.50:22")
// Step 6: Verify client sees standard SSH connection
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
require.Equal(t, peerIP, srcIPAfter, "Client should see traffic from peer IP")
t.Log("✓ Client receives traffic from 100.10.0.50:22 (transparent redirection)")
t.Log("=== SSH Port Redirection Workflow Complete ===")
t.Log("Result: Standard SSH clients can connect to NetBird peers using:")
t.Log(" ssh user@100.10.0.50")
t.Log("Instead of:")
t.Log(" ssh user@100.10.0.50 -p 22022")
}