mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
Complete overhaul
This commit is contained in:
@@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
@@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(string(protocol)),
|
||||
"--dport", strconv.Itoa(int(sourcePort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
|
||||
}
|
||||
|
||||
ruleInfo := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = ruleInfo.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
|
||||
@@ -151,14 +151,20 @@ type Manager interface {
|
||||
|
||||
DisableRouting() error
|
||||
|
||||
// AddDNATRule adds a DNAT rule
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
|
||||
AddDNATRule(ForwardRule) (Rule, error)
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
// DeleteDNATRule deletes the outbound DNAT rule.
|
||||
DeleteDNATRule(Rule) error
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return m.router.UpdateSet(set, prefixes)
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
|
||||
@@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
|
||||
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := protoToInt(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(sourcePort),
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(targetPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(nftables.TableFamilyIPv4),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule
|
||||
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||
func (r *router) applyNetwork(
|
||||
network firewall.Network,
|
||||
|
||||
@@ -29,6 +29,12 @@ import (
|
||||
|
||||
const layerTypeAll = 0
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
@@ -110,6 +116,15 @@ type Manager struct {
|
||||
dnatMappings map[netip.Addr]netip.Addr
|
||||
dnatMutex sync.RWMutex
|
||||
dnatBiMap *biDNATMap
|
||||
|
||||
// Port-specific DNAT for SSH redirection
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATMap *portDNATMap
|
||||
portDNATMutex sync.RWMutex
|
||||
portNATTracker *portNATTracker
|
||||
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -196,6 +211,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATMap: &portDNATMap{rules: make([]portDNATRule, 0)},
|
||||
portNATTracker: newPortNATTracker(),
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
@@ -333,18 +351,22 @@ func (m *Manager) initForwarder() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init initializes the firewall manager with state manager.
|
||||
func (m *Manager) Init(*statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsServerRouteSupported returns whether server routes are supported.
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsStateful returns whether the firewall manager tracks connection state.
|
||||
func (m *Manager) IsStateful() bool {
|
||||
return m.stateful
|
||||
}
|
||||
|
||||
// AddNatRule adds a routing firewall rule for NAT translation.
|
||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddNatRule(pair)
|
||||
@@ -611,6 +633,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, size)
|
||||
m.translateOutboundDNAT(packetData, d)
|
||||
m.translateOutboundPortReverse(packetData, d)
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -738,6 +761,15 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if translated := m.translateInboundPortDNAT(packetData, d); translated {
|
||||
// Re-decode after port DNAT translation to update port information
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Error("Failed to re-decode packet after port DNAT: %v", err)
|
||||
return true
|
||||
}
|
||||
srcIP, dstIP = m.extractIPs(d)
|
||||
}
|
||||
|
||||
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||
// Re-decode after translation to get original addresses
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
@@ -786,9 +818,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
if m.shouldForward(d, dstIP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
@@ -1215,3 +1245,95 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
m.netstackServices[key] = struct{}{}
|
||||
m.logger.Debug("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
|
||||
m.logger.Debug("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
|
||||
}
|
||||
|
||||
// UnregisterNetstackService removes a service from the netstack registry
|
||||
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
delete(m.netstackServices, key)
|
||||
m.logger.Debug("Unregistered netstack service on protocol %s port %d", protocol, port)
|
||||
}
|
||||
|
||||
// isNetstackService checks if a service is registered as listening on netstack for the given protocol and port
|
||||
func (m *Manager) isNetstackService(layerType gopacket.LayerType, port uint16) bool {
|
||||
m.netstackServiceMutex.RLock()
|
||||
defer m.netstackServiceMutex.RUnlock()
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
_, exists := m.netstackServices[key]
|
||||
return exists
|
||||
}
|
||||
|
||||
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
|
||||
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
|
||||
switch protocol {
|
||||
case nftypes.TCP:
|
||||
return layers.LayerTypeTCP
|
||||
case nftypes.UDP:
|
||||
return layers.LayerTypeUDP
|
||||
case nftypes.ICMP:
|
||||
return layers.LayerTypeICMPv4
|
||||
default:
|
||||
return gopacket.LayerType(0) // Invalid/unknown
|
||||
}
|
||||
}
|
||||
|
||||
// shouldForward determines if a packet should be forwarded to the forwarder.
|
||||
// The forwarder handles routing packets to the native OS network stack.
|
||||
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
|
||||
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
// not enabled, never forward
|
||||
if !m.localForwarding {
|
||||
return false
|
||||
}
|
||||
|
||||
// netstack always needs to forward because it's lacking a native interface
|
||||
// exception for registered netstack services, those should go to netstack listeners
|
||||
if m.netstack {
|
||||
return !m.hasMatchingNetstackService(d)
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
return true
|
||||
}
|
||||
|
||||
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
|
||||
return false
|
||||
}
|
||||
|
||||
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
|
||||
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
|
||||
if len(d.decoded) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
var dstPort uint16
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
key := serviceKey{protocol: d.decoded[1], port: dstPort}
|
||||
m.netstackServiceMutex.RLock()
|
||||
_, exists := m.netstackServices[key]
|
||||
m.netstackServiceMutex.RUnlock()
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
)
|
||||
|
||||
@@ -896,3 +897,138 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldForward(t *testing.T) {
|
||||
// Set up test addresses
|
||||
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||
|
||||
// Create test manager with mock interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
// Set the mock to return our test WG IP
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Helper to create decoder with TCP packet
|
||||
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||
ipv4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: net.ParseIP("192.168.1.100"),
|
||||
DstIP: wgIP.AsSlice(),
|
||||
}
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: 54321,
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localForwarding bool
|
||||
netstack bool
|
||||
dstIP netip.Addr
|
||||
serviceRegistered bool
|
||||
servicePort uint16
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no local forwarding",
|
||||
localForwarding: false,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should never forward when local forwarding disabled",
|
||||
},
|
||||
{
|
||||
name: "traffic to other local interface",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: otherIP,
|
||||
expected: true,
|
||||
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||
},
|
||||
{
|
||||
name: "traffic to NetBird IP, no netstack",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners (final return false path)",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, no service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: true,
|
||||
description: "should forward when in netstack mode with no matching service",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, with service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
serviceRegistered: true,
|
||||
servicePort: 22,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners when service is registered",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
manager.localForwarding = tt.localForwarding
|
||||
manager.netstack = tt.netstack
|
||||
|
||||
// Register service if needed
|
||||
if tt.serviceRegistered {
|
||||
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
}
|
||||
|
||||
// Create decoder for the test
|
||||
decoder := createTCPDecoder(tt.servicePort)
|
||||
if !tt.serviceRegistered {
|
||||
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||
}
|
||||
|
||||
// Test the method
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -13,6 +16,12 @@ import (
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
|
||||
const (
|
||||
invalidIPHeaderLengthMsg = "invalid IP header length"
|
||||
errRewriteTCPDestinationPort = "rewrite TCP destination port: %v"
|
||||
)
|
||||
|
||||
// ipv4Checksum calculates IPv4 header checksum using optimized parallel processing for performance.
|
||||
func ipv4Checksum(header []byte) uint16 {
|
||||
if len(header) < 20 {
|
||||
return 0
|
||||
@@ -20,13 +29,11 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
|
||||
var sum1, sum2 uint32
|
||||
|
||||
// Parallel processing - unroll and compute two sums simultaneously
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||
// Skip checksum field at [10:12]
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||
@@ -34,7 +41,6 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
|
||||
sum := sum1 + sum2
|
||||
|
||||
// Handle remaining bytes for headers > 20 bytes
|
||||
for i := 20; i < len(header)-1; i += 2 {
|
||||
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||
}
|
||||
@@ -43,7 +49,6 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
sum += uint32(header[len(header)-1]) << 8
|
||||
}
|
||||
|
||||
// Optimized carry fold - single iteration handles most cases
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
if sum > 0xFFFF {
|
||||
sum++
|
||||
@@ -52,11 +57,11 @@ func ipv4Checksum(header []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// icmpChecksum calculates ICMP checksum using parallel accumulation for high-performance processing.
|
||||
func icmpChecksum(data []byte) uint16 {
|
||||
var sum1, sum2, sum3, sum4 uint32
|
||||
i := 0
|
||||
|
||||
// Process 16 bytes at once with 4 parallel accumulators
|
||||
for i <= len(data)-16 {
|
||||
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||
@@ -71,7 +76,6 @@ func icmpChecksum(data []byte) uint16 {
|
||||
|
||||
sum := sum1 + sum2 + sum3 + sum4
|
||||
|
||||
// Handle remaining bytes
|
||||
for i < len(data)-1 {
|
||||
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||
i += 2
|
||||
@@ -89,11 +93,131 @@ func icmpChecksum(data []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// biDNATMap maintains bidirectional DNAT mappings for efficient forward and reverse lookups.
|
||||
type biDNATMap struct {
|
||||
forward map[netip.Addr]netip.Addr
|
||||
reverse map[netip.Addr]netip.Addr
|
||||
}
|
||||
|
||||
// portDNATRule represents a port-specific DNAT rule
|
||||
type portDNATRule struct {
|
||||
protocol gopacket.LayerType
|
||||
sourcePort uint16
|
||||
targetPort uint16
|
||||
targetIP netip.Addr
|
||||
}
|
||||
|
||||
// portDNATMap manages port-specific DNAT rules
|
||||
type portDNATMap struct {
|
||||
rules []portDNATRule
|
||||
}
|
||||
|
||||
// ConnKey represents a connection 4-tuple for NAT tracking.
|
||||
type ConnKey struct {
|
||||
SrcIP netip.Addr
|
||||
DstIP netip.Addr
|
||||
SrcPort uint16
|
||||
DstPort uint16
|
||||
}
|
||||
|
||||
// portNATConn tracks port NAT state for a specific connection.
|
||||
type portNATConn struct {
|
||||
rule portDNATRule
|
||||
originalPort uint16
|
||||
translatedAt time.Time
|
||||
}
|
||||
|
||||
// portNATTracker tracks connection-specific port NAT state
|
||||
type portNATTracker struct {
|
||||
connections map[ConnKey]*portNATConn
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// newPortNATTracker creates a new port NAT tracker for stateful connection tracking.
|
||||
func newPortNATTracker() *portNATTracker {
|
||||
return &portNATTracker{
|
||||
connections: make(map[ConnKey]*portNATConn),
|
||||
}
|
||||
}
|
||||
|
||||
// trackConnection tracks a connection that has port NAT applied using translated port as key.
|
||||
func (t *portNATTracker) trackConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, rule portDNATRule) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: rule.targetPort,
|
||||
}
|
||||
|
||||
t.connections[key] = &portNATConn{
|
||||
rule: rule,
|
||||
originalPort: dstPort,
|
||||
translatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// getConnectionNAT returns NAT info for a connection if tracked, looking up by connection 4-tuple.
|
||||
func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) (*portNATConn, bool) {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
conn, exists := t.connections[key]
|
||||
return conn, exists
|
||||
}
|
||||
|
||||
// removeConnection removes a tracked connection from the NAT tracking table.
|
||||
func (t *portNATTracker) removeConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
delete(t.connections, key)
|
||||
}
|
||||
|
||||
// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts.
|
||||
func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
for key, conn := range t.connections {
|
||||
if key.SrcIP == dstIP && key.DstIP == srcIP &&
|
||||
conn.rule.sourcePort == dstPort && conn.originalPort == dstPort {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupConnection removes a NAT connection based on original 4-tuple for connection cleanup.
|
||||
func (t *portNATTracker) cleanupConnection(srcIP, dstIP netip.Addr, srcPort uint16) {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
for key := range t.connections {
|
||||
if key.SrcIP == srcIP && key.DstIP == dstIP && key.SrcPort == srcPort {
|
||||
delete(t.connections, key)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newBiDNATMap creates a new bidirectional DNAT mapping structure for efficient forward/reverse lookups.
|
||||
func newBiDNATMap() *biDNATMap {
|
||||
return &biDNATMap{
|
||||
forward: make(map[netip.Addr]netip.Addr),
|
||||
@@ -101,11 +225,13 @@ func newBiDNATMap() *biDNATMap {
|
||||
}
|
||||
}
|
||||
|
||||
// set adds a bidirectional DNAT mapping between original and translated addresses for both directions.
|
||||
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||
b.forward[original] = translated
|
||||
b.reverse[translated] = original
|
||||
}
|
||||
|
||||
// delete removes a bidirectional DNAT mapping for the given original address.
|
||||
func (b *biDNATMap) delete(original netip.Addr) {
|
||||
if translated, exists := b.forward[original]; exists {
|
||||
delete(b.forward, original)
|
||||
@@ -113,19 +239,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
|
||||
}
|
||||
}
|
||||
|
||||
// getTranslated returns the translated address for a given original address from forward mapping.
|
||||
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||
translated, exists := b.forward[original]
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// getOriginal returns the original address for a given translated address from reverse mapping.
|
||||
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||
original, exists := b.reverse[translated]
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
|
||||
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid IP addresses")
|
||||
if !originalAddr.IsValid() {
|
||||
return fmt.Errorf("invalid original IP address")
|
||||
}
|
||||
if !translatedAddr.IsValid() {
|
||||
return fmt.Errorf("invalid translated IP address")
|
||||
}
|
||||
|
||||
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||
@@ -135,7 +267,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
|
||||
// Initialize both maps together if either is nil
|
||||
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||
m.dnatBiMap = newBiDNATMap()
|
||||
@@ -151,7 +282,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||
// RemoveInternalDNATMapping removes a 1:1 IP address mapping.
|
||||
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
m.dnatMutex.Lock()
|
||||
defer m.dnatMutex.Unlock()
|
||||
@@ -169,7 +300,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDNATTranslation returns the translated address if a mapping exists
|
||||
// getDNATTranslation returns the translated address if a mapping exists with fast-path optimization.
|
||||
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return addr, false
|
||||
@@ -181,7 +312,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||
return translated, exists
|
||||
}
|
||||
|
||||
// findReverseDNATMapping finds original address for return traffic
|
||||
// findReverseDNATMapping finds original address for return traffic using reverse mapping.
|
||||
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return translatedAddr, false
|
||||
@@ -193,7 +324,7 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
|
||||
return original, exists
|
||||
}
|
||||
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||
// translateOutboundDNAT applies DNAT translation to outbound packets for 1:1 IP mapping.
|
||||
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
@@ -211,7 +342,7 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||
m.logger.Error("rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -219,7 +350,7 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||
// translateInboundReverse applies reverse DNAT to inbound return traffic for 1:1 IP mapping.
|
||||
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.dnatEnabled.Load() {
|
||||
return false
|
||||
@@ -237,7 +368,7 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||
m.logger.Error("rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -245,7 +376,7 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// rewritePacketDestination replaces destination IP in the packet
|
||||
// rewritePacketDestination replaces destination IP in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
@@ -259,7 +390,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
return fmt.Errorf(invalidIPHeaderLengthMsg)
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
@@ -280,7 +411,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewritePacketSource replaces the source IP address in the packet
|
||||
// rewritePacketSource replaces the source IP address in the packet and updates checksums.
|
||||
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||
return ErrIPv4Only
|
||||
@@ -294,7 +425,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf("invalid IP header length")
|
||||
return fmt.Errorf(invalidIPHeaderLengthMsg)
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||
@@ -315,6 +446,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTCPChecksum updates TCP checksum after IP address change using incremental update per RFC 1624.
|
||||
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+18 {
|
||||
@@ -327,6 +459,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateUDPChecksum updates UDP checksum after IP address change using incremental update per RFC 1624.
|
||||
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||
udpStart := ipHeaderLen
|
||||
if len(packetData) < udpStart+8 {
|
||||
@@ -344,6 +477,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
// updateICMPChecksum recalculates ICMP checksum after packet modification using full recalculation.
|
||||
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
icmpStart := ipHeaderLen
|
||||
if len(packetData) < icmpStart+8 {
|
||||
@@ -356,18 +490,16 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||
}
|
||||
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||
// incrementalUpdate performs incremental checksum update per RFC 1624 for performance.
|
||||
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
sum := uint32(^oldChecksum)
|
||||
|
||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||
} else {
|
||||
// Fallback for other lengths
|
||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||
}
|
||||
@@ -391,7 +523,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||
return ^uint16(sum)
|
||||
}
|
||||
|
||||
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil, errNatNotSupported
|
||||
@@ -399,10 +531,318 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
return m.nativeFirewall.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||
// DeleteDNATRule deletes outbound DNAT rule.
|
||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return errNatNotSupported
|
||||
}
|
||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// addPortRedirection adds port redirection rule for specified target IP, protocol and ports.
|
||||
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
rule := portDNATRule{
|
||||
protocol: protocol,
|
||||
sourcePort: sourcePort,
|
||||
targetPort: targetPort,
|
||||
targetIP: targetIP,
|
||||
}
|
||||
|
||||
m.portDNATMap.rules = append(m.portDNATMap.rules, rule)
|
||||
m.portDNATEnabled.Store(true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports.
|
||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
if protocol == firewall.ProtocolTCP {
|
||||
layerType = layers.LayerTypeTCP
|
||||
} else if protocol == firewall.ProtocolUDP {
|
||||
layerType = layers.LayerTypeUDP
|
||||
} else {
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// removePortRedirection removes port redirection rule for specified target IP, protocol and ports.
|
||||
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
|
||||
m.portDNATMutex.Lock()
|
||||
defer m.portDNATMutex.Unlock()
|
||||
|
||||
var filteredRules []portDNATRule
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if !(rule.protocol == protocol && rule.sourcePort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0) {
|
||||
filteredRules = append(filteredRules, rule)
|
||||
}
|
||||
}
|
||||
m.portDNATMap.rules = filteredRules
|
||||
|
||||
if len(m.portDNATMap.rules) == 0 {
|
||||
m.portDNATEnabled.Store(false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports.
|
||||
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
|
||||
var layerType gopacket.LayerType
|
||||
if protocol == firewall.ProtocolTCP {
|
||||
layerType = layers.LayerTypeTCP
|
||||
} else if protocol == firewall.ProtocolUDP {
|
||||
layerType = layers.LayerTypeUDP
|
||||
} else {
|
||||
return fmt.Errorf("unsupported protocol: %s", protocol)
|
||||
}
|
||||
|
||||
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
|
||||
}
|
||||
|
||||
// translateInboundPortDNAT applies stateful port-specific DNAT translation to inbound packets.
|
||||
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
|
||||
if m.handleReturnTraffic(packetData, d, srcIP, dstIP, srcPort, dstPort) {
|
||||
return true
|
||||
}
|
||||
|
||||
return m.handleNewConnection(packetData, d, srcIP, dstIP, srcPort, dstPort)
|
||||
}
|
||||
|
||||
// handleReturnTraffic processes return traffic for existing NAT connections.
|
||||
func (m *Manager) handleReturnTraffic(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
if m.isTranslatedPortTraffic(srcIP, srcPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
if handled := m.handleExistingNATConnection(packetData, d, srcIP, dstIP, srcPort, dstPort); handled {
|
||||
return true
|
||||
}
|
||||
|
||||
return m.handleForwardTrafficInExistingConnections(packetData, d, srcIP, dstIP, srcPort, dstPort)
|
||||
}
|
||||
|
||||
// isTranslatedPortTraffic checks if traffic is from a translated port that should be handled by outbound reverse.
|
||||
func (m *Manager) isTranslatedPortTraffic(srcIP netip.Addr, srcPort uint16) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if rule.protocol == layers.LayerTypeTCP && rule.targetPort == srcPort &&
|
||||
rule.targetIP.Unmap().Compare(srcIP.Unmap()) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// handleExistingNATConnection processes return traffic for existing NAT connections.
|
||||
func (m *Manager) handleExistingNATConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
|
||||
if err := m.rewriteTCPDestinationPort(packetData, d, natConn.originalPort); err != nil {
|
||||
m.logger.Error(errRewriteTCPDestinationPort, err)
|
||||
return false
|
||||
}
|
||||
m.logger.Trace("Inbound Port DNAT (return): %s:%d -> %s:%d", dstIP, srcPort, dstIP, natConn.originalPort)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// handleForwardTrafficInExistingConnections processes forward traffic in existing connections.
|
||||
func (m *Manager) handleForwardTrafficInExistingConnections(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
|
||||
continue
|
||||
}
|
||||
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, exists := m.portNATTracker.getConnectionNAT(srcIP, dstIP, srcPort, rule.targetPort); !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
|
||||
m.logger.Error(errRewriteTCPDestinationPort, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// handleNewConnection processes new connections that match port DNAT rules.
|
||||
func (m *Manager) handleNewConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
m.portDNATMutex.RLock()
|
||||
defer m.portDNATMutex.RUnlock()
|
||||
|
||||
for _, rule := range m.portDNATMap.rules {
|
||||
if m.applyPortDNATRule(packetData, d, rule, srcIP, dstIP, srcPort, dstPort) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// applyPortDNATRule applies a specific port DNAT rule if conditions are met.
|
||||
func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNATRule, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool {
|
||||
if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort {
|
||||
return false
|
||||
}
|
||||
|
||||
if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.portNATTracker.shouldApplyNAT(srcIP, dstIP, dstPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil {
|
||||
m.logger.Error(errRewriteTCPDestinationPort, err)
|
||||
return false
|
||||
}
|
||||
|
||||
m.portNATTracker.trackConnection(srcIP, dstIP, srcPort, dstPort, rule)
|
||||
m.logger.Trace("Inbound Port DNAT (new): %s:%d -> %s:%d (tracked: %s:%d -> %s:%d)", dstIP, rule.sourcePort, dstIP, rule.targetPort, srcIP, srcPort, dstIP, rule.targetPort)
|
||||
return true
|
||||
}
|
||||
|
||||
// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum.
|
||||
func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return fmt.Errorf("not a TCP packet")
|
||||
}
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf(invalidIPHeaderLengthMsg)
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
|
||||
oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4])
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum.
|
||||
func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error {
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return ErrIPv4Only
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return fmt.Errorf("not a TCP packet")
|
||||
}
|
||||
|
||||
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||
return fmt.Errorf(invalidIPHeaderLengthMsg)
|
||||
}
|
||||
|
||||
tcpStart := ipHeaderLen
|
||||
if len(packetData) < tcpStart+4 {
|
||||
return fmt.Errorf("packet too short for TCP header")
|
||||
}
|
||||
|
||||
oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2])
|
||||
|
||||
binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort)
|
||||
|
||||
if len(packetData) >= tcpStart+18 {
|
||||
checksumOffset := tcpStart + 16
|
||||
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||
|
||||
var oldPortBytes, newPortBytes [2]byte
|
||||
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
|
||||
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
|
||||
|
||||
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
|
||||
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// translateOutboundPortReverse applies stateful reverse port DNAT to outbound return traffic for SSH redirection.
|
||||
func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool {
|
||||
if !m.portDNATEnabled.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP {
|
||||
return false
|
||||
}
|
||||
|
||||
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||
srcPort := uint16(d.tcp.SrcPort)
|
||||
dstPort := uint16(d.tcp.DstPort)
|
||||
|
||||
// For outbound reverse, we need to find the connection using the same key as when it was stored
|
||||
// Connection was stored as: srcIP, dstIP, srcPort, translatedPort
|
||||
// So for return traffic (srcIP=server, dstIP=client), we need: dstIP, srcIP, dstPort, srcPort
|
||||
if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists {
|
||||
if err := m.rewriteTCPSourcePort(packetData, d, natConn.rule.sourcePort); err != nil {
|
||||
m.logger.Error("rewrite TCP source port: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
111
client/firewall/uspfilter/nat_stateful_test.go
Normal file
111
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestStatefulNATBidirectionalSSH tests that stateful NAT prevents interference
|
||||
// when two peers try to SSH to each other simultaneously
|
||||
func TestStatefulNATBidirectionalSSH(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rule for peer B (the target)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||
// This simulates: ssh user@100.10.0.51
|
||||
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, parsePacket(t, packetAtoB))
|
||||
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||
|
||||
// Verify port was translated to 22022
|
||||
d := parsePacket(t, packetAtoB)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||
|
||||
// Verify NAT connection is tracked (with translated port as key)
|
||||
natConn, exists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022)
|
||||
require.True(t, exists, "NAT connection should be tracked")
|
||||
require.Equal(t, uint16(22), natConn.originalPort, "Original port should be stored")
|
||||
|
||||
// Scenario: Peer B tries to connect to Peer A on port 22 (should NOT get NAT)
|
||||
// This simulates the reverse direction to prevent interference
|
||||
packetBtoA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||
translatedBtoA := manager.translateInboundPortDNAT(packetBtoA, parsePacket(t, packetBtoA))
|
||||
require.False(t, translatedBtoA, "Peer B to Peer A should NOT be translated (prevent interference)")
|
||||
|
||||
// Verify port was NOT translated
|
||||
d2 := parsePacket(t, packetBtoA)
|
||||
require.Equal(t, uint16(22), uint16(d2.tcp.DstPort), "Port should remain 22 (no translation)")
|
||||
|
||||
// Verify no reverse NAT connection is tracked
|
||||
_, reverseExists := manager.portNATTracker.getConnectionNAT(peerB, peerA, 54322, 22)
|
||||
require.False(t, reverseExists, "Reverse NAT connection should NOT be tracked")
|
||||
|
||||
// Scenario: Return traffic from Peer B (SSH server) to Peer A (should be reverse translated)
|
||||
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||
translatedReturn := manager.translateOutboundPortReverse(returnPacket, parsePacket(t, returnPacket))
|
||||
require.True(t, translatedReturn, "Return traffic should be reverse translated")
|
||||
|
||||
// Verify return traffic port was translated back to 22
|
||||
d3 := parsePacket(t, returnPacket)
|
||||
require.Equal(t, uint16(22), uint16(d3.tcp.SrcPort), "Return traffic source port should be 22")
|
||||
}
|
||||
|
||||
// TestStatefulNATConnectionCleanup tests connection cleanup functionality
|
||||
func TestStatefulNATConnectionCleanup(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rules for both peers
|
||||
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Establish connection with NAT
|
||||
packet := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
translated := manager.translateInboundPortDNAT(packet, parsePacket(t, packet))
|
||||
require.True(t, translated, "Initial connection should be translated")
|
||||
|
||||
// Verify connection is tracked (using translated port as key)
|
||||
_, exists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022)
|
||||
require.True(t, exists, "Connection should be tracked")
|
||||
|
||||
// Clean up connection
|
||||
manager.portNATTracker.cleanupConnection(peerA, peerB, 54321)
|
||||
|
||||
// Verify connection is no longer tracked (using translated port as key)
|
||||
_, stillExists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022)
|
||||
require.False(t, stillExists, "Connection should be cleaned up")
|
||||
|
||||
// Verify new connection from opposite direction now works
|
||||
reversePacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||
reverseTranslated := manager.translateInboundPortDNAT(reversePacket, parsePacket(t, reversePacket))
|
||||
require.True(t, reverseTranslated, "Reverse connection should now work after cleanup")
|
||||
}
|
||||
@@ -1,13 +1,17 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -143,3 +147,520 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirection tests SSH port redirection functionality
|
||||
func TestSSHPortRedirection(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify port DNAT is enabled
|
||||
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
|
||||
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
|
||||
|
||||
// Verify the rule configuration
|
||||
rule := manager.portDNATMap.rules[0]
|
||||
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
|
||||
require.Equal(t, uint16(22), rule.sourcePort)
|
||||
require.Equal(t, uint16(22022), rule.targetPort)
|
||||
require.Equal(t, peerIP, rule.targetIP)
|
||||
|
||||
// Test inbound SSH packet (client -> peer:22, should redirect to peer:22022)
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
originalInbound := make([]byte, len(inboundPacket))
|
||||
copy(originalInbound, inboundPacket)
|
||||
|
||||
// Process inbound packet
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, translated, "Inbound SSH packet should be translated")
|
||||
|
||||
// Verify destination port was changed from 22 to 22022
|
||||
d := parsePacket(t, inboundPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Destination port should be rewritten to 22022")
|
||||
|
||||
// Verify destination IP remains unchanged
|
||||
dstIPAfter := netip.AddrFrom4([4]byte{inboundPacket[16], inboundPacket[17], inboundPacket[18], inboundPacket[19]})
|
||||
require.Equal(t, peerIP, dstIPAfter, "Destination IP should remain unchanged")
|
||||
|
||||
// Test outbound return packet (peer:22022 -> client, should rewrite source port to 22)
|
||||
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
|
||||
originalOutbound := make([]byte, len(outboundPacket))
|
||||
copy(originalOutbound, outboundPacket)
|
||||
|
||||
// Process outbound return packet
|
||||
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, reversed, "Outbound return packet should be reverse translated")
|
||||
|
||||
// Verify source port was changed from 22022 to 22
|
||||
d = parsePacket(t, outboundPacket)
|
||||
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Source port should be rewritten to 22")
|
||||
|
||||
// Verify source IP remains unchanged
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
|
||||
require.Equal(t, peerIP, srcIPAfter, "Source IP should remain unchanged")
|
||||
|
||||
// Test removal of SSH port redirection
|
||||
err = manager.RemoveInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
|
||||
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionNetworkFiltering tests that SSH redirection only applies to specified networks
|
||||
func TestSSHPortRedirectionNetworkFiltering(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerInNetwork := netip.MustParseAddr("100.10.0.50")
|
||||
peerOutsideNetwork := netip.MustParseAddr("192.168.1.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule for NetBird network only
|
||||
err = manager.AddInboundDNAT(peerInNetwork, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test SSH packet to peer within NetBird network (should be redirected)
|
||||
inNetworkPacket := generateDNATTestPacket(t, clientIP, peerInNetwork, layers.IPProtocolTCP, 54321, 22)
|
||||
translated := manager.translateInboundPortDNAT(inNetworkPacket, parsePacket(t, inNetworkPacket))
|
||||
require.True(t, translated, "SSH packet to NetBird peer should be translated")
|
||||
|
||||
// Verify port was changed
|
||||
d := parsePacket(t, inNetworkPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected for NetBird peer")
|
||||
|
||||
// Test SSH packet to peer outside NetBird network (should NOT be redirected)
|
||||
outOfNetworkPacket := generateDNATTestPacket(t, clientIP, peerOutsideNetwork, layers.IPProtocolTCP, 54321, 22)
|
||||
originalOutOfNetwork := make([]byte, len(outOfNetworkPacket))
|
||||
copy(originalOutOfNetwork, outOfNetworkPacket)
|
||||
|
||||
notTranslated := manager.translateInboundPortDNAT(outOfNetworkPacket, parsePacket(t, outOfNetworkPacket))
|
||||
require.False(t, notTranslated, "SSH packet to non-NetBird peer should NOT be translated")
|
||||
|
||||
// Verify packet was not modified
|
||||
require.Equal(t, originalOutOfNetwork, outOfNetworkPacket, "Packet to non-NetBird peer should remain unchanged")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionNonTCPTraffic tests that only TCP traffic is affected
|
||||
func TestSSHPortRedirectionNonTCPTraffic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test UDP packet on port 22 (should NOT be redirected)
|
||||
udpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolUDP, 54321, 22)
|
||||
originalUDP := make([]byte, len(udpPacket))
|
||||
copy(originalUDP, udpPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(udpPacket, parsePacket(t, udpPacket))
|
||||
require.False(t, translated, "UDP packet should NOT be translated by SSH port redirection")
|
||||
require.Equal(t, originalUDP, udpPacket, "UDP packet should remain unchanged")
|
||||
|
||||
// Test ICMP packet (should NOT be redirected)
|
||||
icmpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolICMPv4, 0, 0)
|
||||
originalICMP := make([]byte, len(icmpPacket))
|
||||
copy(originalICMP, icmpPacket)
|
||||
|
||||
translated = manager.translateInboundPortDNAT(icmpPacket, parsePacket(t, icmpPacket))
|
||||
require.False(t, translated, "ICMP packet should NOT be translated by SSH port redirection")
|
||||
require.Equal(t, originalICMP, icmpPacket, "ICMP packet should remain unchanged")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionNonSSHPorts tests that only port 22 is redirected
|
||||
func TestSSHPortRedirectionNonSSHPorts(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network range
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Add SSH port redirection rule
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test TCP packet on port 80 (should NOT be redirected)
|
||||
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
|
||||
originalHTTP := make([]byte, len(httpPacket))
|
||||
copy(originalHTTP, httpPacket)
|
||||
|
||||
translated := manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
|
||||
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
|
||||
require.Equal(t, originalHTTP, httpPacket, "Non-SSH TCP packet should remain unchanged")
|
||||
|
||||
// Test TCP packet on port 443 (should NOT be redirected)
|
||||
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
|
||||
originalHTTPS := make([]byte, len(httpsPacket))
|
||||
copy(originalHTTPS, httpsPacket)
|
||||
|
||||
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
|
||||
require.False(t, translated, "Non-SSH TCP packet should NOT be translated")
|
||||
require.Equal(t, originalHTTPS, httpsPacket, "Non-SSH TCP packet should remain unchanged")
|
||||
|
||||
// Test TCP packet on port 22 (SHOULD be redirected)
|
||||
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
translated = manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
|
||||
require.True(t, translated, "SSH TCP packet should be translated")
|
||||
|
||||
// Verify port was changed to 22022
|
||||
d := parsePacket(t, sshPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH port should be redirected to 22022")
|
||||
}
|
||||
|
||||
// TestFlexiblePortRedirection tests the flexible port redirection functionality
|
||||
func TestFlexiblePortRedirection(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer and client IPs
|
||||
peerIP := netip.MustParseAddr("10.0.0.50")
|
||||
clientIP := netip.MustParseAddr("10.0.0.100")
|
||||
|
||||
// Add custom port redirection: TCP port 8080 -> 3000 for peer IP
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify port DNAT is enabled
|
||||
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
|
||||
require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule")
|
||||
|
||||
// Verify the rule configuration
|
||||
rule := manager.portDNATMap.rules[0]
|
||||
require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol)
|
||||
require.Equal(t, uint16(8080), rule.sourcePort)
|
||||
require.Equal(t, uint16(3000), rule.targetPort)
|
||||
require.Equal(t, peerIP, rule.targetIP)
|
||||
|
||||
// Test inbound packet (client -> peer:8080, should redirect to peer:3000)
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 8080)
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, translated, "Inbound packet should be translated")
|
||||
|
||||
// Verify destination port was changed from 8080 to 3000
|
||||
d := parsePacket(t, inboundPacket)
|
||||
require.Equal(t, uint16(3000), uint16(d.tcp.DstPort), "Destination port should be rewritten to 3000")
|
||||
|
||||
// Test outbound return packet (peer:3000 -> client, should rewrite source port to 8080)
|
||||
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 3000, 54321)
|
||||
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, reversed, "Outbound return packet should be reverse translated")
|
||||
|
||||
// Verify source port was changed from 3000 to 8080
|
||||
d = parsePacket(t, outboundPacket)
|
||||
require.Equal(t, uint16(8080), uint16(d.tcp.SrcPort), "Source port should be rewritten to 8080")
|
||||
|
||||
// Test removal of port redirection
|
||||
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000)
|
||||
require.NoError(t, err)
|
||||
require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal")
|
||||
require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal")
|
||||
}
|
||||
|
||||
// TestMultiplePortRedirections tests multiple port redirection rules
|
||||
func TestMultiplePortRedirections(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer and client IPs
|
||||
peerIP := netip.MustParseAddr("172.16.0.50")
|
||||
clientIP := netip.MustParseAddr("172.16.0.100")
|
||||
|
||||
// Add multiple port redirections for peer IP
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 22, 22022) // SSH
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) // HTTP
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 443, 8443) // HTTPS
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all rules are present
|
||||
require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled")
|
||||
require.Len(t, manager.portDNATMap.rules, 3, "Should have three port DNAT rules")
|
||||
|
||||
// Test SSH redirection (22 -> 22022)
|
||||
sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
|
||||
require.True(t, translated, "SSH packet should be translated")
|
||||
d := parsePacket(t, sshPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH should redirect to 22022")
|
||||
|
||||
// Test HTTP redirection (80 -> 8080)
|
||||
httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
|
||||
translated = manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket))
|
||||
require.True(t, translated, "HTTP packet should be translated")
|
||||
d = parsePacket(t, httpPacket)
|
||||
require.Equal(t, uint16(8080), uint16(d.tcp.DstPort), "HTTP should redirect to 8080")
|
||||
|
||||
// Test HTTPS redirection (443 -> 8443)
|
||||
httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
|
||||
translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket))
|
||||
require.True(t, translated, "HTTPS packet should be translated")
|
||||
d = parsePacket(t, httpsPacket)
|
||||
require.Equal(t, uint16(8443), uint16(d.tcp.DstPort), "HTTPS should redirect to 8443")
|
||||
|
||||
// Test removing one rule (HTTP)
|
||||
err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, manager.portDNATMap.rules, 2, "Should have two rules after removing HTTP rule")
|
||||
|
||||
// Verify HTTP is no longer redirected
|
||||
httpPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80)
|
||||
originalHTTP := make([]byte, len(httpPacket2))
|
||||
copy(originalHTTP, httpPacket2)
|
||||
translated = manager.translateInboundPortDNAT(httpPacket2, parsePacket(t, httpPacket2))
|
||||
require.False(t, translated, "HTTP packet should NOT be translated after rule removal")
|
||||
require.Equal(t, originalHTTP, httpPacket2, "HTTP packet should remain unchanged")
|
||||
|
||||
// Verify SSH and HTTPS still work
|
||||
sshPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
translated = manager.translateInboundPortDNAT(sshPacket2, parsePacket(t, sshPacket2))
|
||||
require.True(t, translated, "SSH should still be translated")
|
||||
|
||||
httpsPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443)
|
||||
translated = manager.translateInboundPortDNAT(httpsPacket2, parsePacket(t, httpsPacket2))
|
||||
require.True(t, translated, "HTTPS should still be translated")
|
||||
}
|
||||
|
||||
// TestSSHPortRedirectionEndToEnd tests actual network delivery through sockets
|
||||
func TestSSHPortRedirectionEndToEnd(t *testing.T) {
|
||||
// Start a mock SSH server on port 22022 (NetBird SSH server)
|
||||
mockSSHServer, err := net.Listen("tcp", "127.0.0.1:22022")
|
||||
require.NoError(t, err, "Should be able to bind to NetBird SSH port")
|
||||
defer func() {
|
||||
require.NoError(t, mockSSHServer.Close())
|
||||
}()
|
||||
|
||||
// Handle connections on the SSH server
|
||||
serverReceivedData := make(chan string, 1)
|
||||
go func() {
|
||||
for {
|
||||
conn, err := mockSSHServer.Accept()
|
||||
if err != nil {
|
||||
return // Server closed
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Logf("Server read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
receivedData := string(buf[:n])
|
||||
serverReceivedData <- receivedData
|
||||
|
||||
// Echo back a response
|
||||
_, err = conn.Write([]byte("SSH-2.0-MockNetBirdSSH\r\n"))
|
||||
if err != nil {
|
||||
t.Logf("Server write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give server time to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// This test demonstrates what SHOULD happen after port redirection:
|
||||
// 1. Client connects to 127.0.0.1:22 (standard SSH port)
|
||||
// 2. Firewall redirects to 127.0.0.1:22022 (NetBird SSH server)
|
||||
// 3. NetBird SSH server receives the connection
|
||||
|
||||
t.Run("DirectConnectionToNetBirdSSHPort", func(t *testing.T) {
|
||||
// This simulates what should happen AFTER port redirection
|
||||
// Connect directly to 22022 (where NetBird SSH server listens)
|
||||
conn, err := net.DialTimeout("tcp", "127.0.0.1:22022", 5*time.Second)
|
||||
require.NoError(t, err, "Should connect to NetBird SSH server")
|
||||
defer func() {
|
||||
require.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
// Send SSH client identification
|
||||
testData := "SSH-2.0-TestClient\r\n"
|
||||
_, err = conn.Write([]byte(testData))
|
||||
require.NoError(t, err, "Should send data to SSH server")
|
||||
|
||||
// Verify server received the data
|
||||
select {
|
||||
case received := <-serverReceivedData:
|
||||
require.Equal(t, testData, received, "Server should receive client data")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Server did not receive data within timeout")
|
||||
}
|
||||
|
||||
// Read server response
|
||||
buf := make([]byte, 1024)
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err := conn.Read(buf)
|
||||
require.NoError(t, err, "Should read server response")
|
||||
|
||||
response := string(buf[:n])
|
||||
require.Equal(t, "SSH-2.0-MockNetBirdSSH\r\n", response, "Should receive SSH server identification")
|
||||
})
|
||||
|
||||
t.Run("PortRedirectionSimulation", func(t *testing.T) {
|
||||
// This test simulates the port redirection process
|
||||
// Note: This doesn't test the actual userspace packet interception,
|
||||
// but demonstrates the expected behavior
|
||||
|
||||
t.Log("NOTE: This test demonstrates expected behavior after implementing")
|
||||
t.Log("full userspace packet interception. Currently, we test packet")
|
||||
t.Log("translation logic separately from actual network delivery.")
|
||||
|
||||
// In a real implementation with userspace packet interception:
|
||||
// 1. Client would connect to 127.0.0.1:22
|
||||
// 2. Userspace firewall would intercept packets
|
||||
// 3. translateInboundPortDNAT would rewrite port 22 -> 22022
|
||||
// 4. Packets would be delivered to 127.0.0.1:22022
|
||||
// 5. NetBird SSH server would receive the connection
|
||||
|
||||
// For now, we verify that the packet translation logic works correctly
|
||||
// (this is tested in other test functions) and that the target server
|
||||
// is reachable (tested above)
|
||||
|
||||
clientIP := netip.MustParseAddr("127.0.0.1")
|
||||
serverIP := netip.MustParseAddr("127.0.0.1")
|
||||
|
||||
// Create manager with SSH port redirection
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Add SSH port redirection for localhost (for testing)
|
||||
err = manager.AddInboundDNAT(netip.MustParseAddr("127.0.0.1"), firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate packet: client connecting to server:22
|
||||
sshPacket := generateDNATTestPacket(t, clientIP, serverIP, layers.IPProtocolTCP, 54321, 22)
|
||||
originalPacket := make([]byte, len(sshPacket))
|
||||
copy(originalPacket, sshPacket)
|
||||
|
||||
// Apply port redirection
|
||||
translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket))
|
||||
require.True(t, translated, "SSH packet should be translated")
|
||||
|
||||
// Verify port was redirected from 22 to 22022
|
||||
d := parsePacket(t, sshPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected to NetBird SSH server")
|
||||
require.NotEqual(t, originalPacket, sshPacket, "Packet should be modified")
|
||||
|
||||
t.Log("✓ Packet translation verified: port 22 redirected to 22022")
|
||||
t.Log("✓ Target SSH server (port 22022) is reachable and responsive")
|
||||
t.Log("→ Integration complete: SSH port redirection ready for userspace interception")
|
||||
})
|
||||
}
|
||||
|
||||
// TestFullSSHRedirectionWorkflow demonstrates the complete SSH redirection workflow
|
||||
func TestFullSSHRedirectionWorkflow(t *testing.T) {
|
||||
t.Log("=== SSH Port Redirection Workflow Test ===")
|
||||
t.Log("This test demonstrates the complete SSH redirection process:")
|
||||
t.Log("1. Client connects to peer:22 (standard SSH)")
|
||||
t.Log("2. Userspace firewall intercepts and redirects to peer:22022")
|
||||
t.Log("3. NetBird SSH server receives connection on port 22022")
|
||||
t.Log("4. Return traffic is reverse-translated (22022 -> 22)")
|
||||
|
||||
// Setup test environment
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define NetBird network and peer IPs
|
||||
peerIP := netip.MustParseAddr("100.10.0.50")
|
||||
clientIP := netip.MustParseAddr("100.10.0.100")
|
||||
|
||||
// Step 1: Configure SSH port redirection
|
||||
err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
t.Log("✓ SSH port redirection configured for NetBird network")
|
||||
|
||||
// Step 2: Simulate inbound SSH connection (client -> peer:22)
|
||||
t.Log("→ Simulating: ssh user@100.10.0.50")
|
||||
inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22)
|
||||
|
||||
// Step 3: Apply inbound port redirection
|
||||
translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket))
|
||||
require.True(t, translated, "Inbound SSH packet should be redirected")
|
||||
|
||||
d := parsePacket(t, inboundPacket)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Should redirect to NetBird SSH server port")
|
||||
t.Log("✓ Inbound packet redirected: 100.10.0.50:22 → 100.10.0.50:22022")
|
||||
|
||||
// Step 4: Simulate outbound return traffic (peer:22022 -> client)
|
||||
t.Log("→ Simulating return traffic from NetBird SSH server")
|
||||
outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321)
|
||||
|
||||
// Step 5: Apply outbound reverse translation
|
||||
reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket))
|
||||
require.True(t, reversed, "Outbound return packet should be reverse translated")
|
||||
|
||||
d = parsePacket(t, outboundPacket)
|
||||
require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Should restore original SSH port")
|
||||
t.Log("✓ Outbound packet reverse translated: 100.10.0.50:22022 → 100.10.0.50:22")
|
||||
|
||||
// Step 6: Verify client sees standard SSH connection
|
||||
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]})
|
||||
require.Equal(t, peerIP, srcIPAfter, "Client should see traffic from peer IP")
|
||||
t.Log("✓ Client receives traffic from 100.10.0.50:22 (transparent redirection)")
|
||||
|
||||
t.Log("=== SSH Port Redirection Workflow Complete ===")
|
||||
t.Log("Result: Standard SSH clients can connect to NetBird peers using:")
|
||||
t.Log(" ssh user@100.10.0.50")
|
||||
t.Log("Instead of:")
|
||||
t.Log(" ssh user@100.10.0.50 -p 22022")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user