diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 81f7a9125..16b50211e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..80aea7cf8 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + dnatRule := []string{ + "-i", r.wgIface.Name(), + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + ruleInfo := ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + r.rules[ruleID] = ruleInfo.rule + + r.updateState() + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + return fmt.Errorf("delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3b3164823..7ee33118b 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -151,14 +151,20 @@ type Manager interface { DisableRouting() error - // AddDNATRule adds a DNAT rule + // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. AddDNATRule(ForwardRule) (Rule, error) - // DeleteDNATRule deletes a DNAT rule + // DeleteDNATRule deletes the outbound DNAT rule. DeleteDNATRule(Rule) error // UpdateSet updates the set with the given prefixes UpdateSet(hash Set, prefixes []netip.Prefix) error + + // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveInboundDNAT removes inbound DNAT rule + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 560f224f5..aa90d3b9a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..648a6aedf 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 3, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 3, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) + + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + RegProtoMax: 0, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if rule, exists := r.rules[ruleID]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7eef49e31..fbc39b740 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -109,6 +109,10 @@ type Manager struct { dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex dnatBiMap *biDNATMap + + portDNATEnabled atomic.Bool + portDNATRules []portDNATRule + portDNATMutex sync.RWMutex } // decoder for packages @@ -122,6 +126,8 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor @@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), + portDNATRules: []portDNATRule{}, } m.routingEnabled.Store(false) @@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) return false @@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -759,10 +780,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { + // Re-decode after port DNAT translation to update port information + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 5614e2ec3..139f702f2 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -50,6 +50,8 @@ type logMessage struct { arg4 any arg5 any arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } - func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { } } +func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { @@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { } } +// Trace8 logs a trace message with 8 arguments (8 placeholder in format string) +func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + default: + } + } +} + func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") @@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { argCount++ if msg.arg6 != nil { argCount++ + if msg.arg7 != nil { + argCount++ + if msg.arg8 != nil { + argCount++ + } + } } } } @@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) case 6: formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + case 7: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7) + case 8: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8) } *buf = append(*buf, formatted...) @@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} \ No newline at end of file +} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 27b752531..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/netip" + "slices" + "github.com/google/gopacket" "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -13,6 +15,21 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + +const ( + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 +) + +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } +// portDNATRule represents a port-specific DNAT rule. +type portDNATRule struct { + protocol gopacket.LayerType + origPort uint16 + targetPort uint16 + targetIP netip.Addr +} + +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap { } } +// set adds a bidirectional DNAT mapping between original and translated addresses. func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original } +// delete removes a bidirectional DNAT mapping for the given original address. func (b *biDNATMap) delete(original netip.Addr) { if translated, exists := b.forward[original]; exists { delete(b.forward, original) @@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) { } } +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists } +// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation. func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { - if !originalAddr.IsValid() || !translatedAddr.IsValid() { - return fmt.Errorf("invalid IP addresses") + if !originalAddr.IsValid() { + return fmt.Errorf("invalid original IP address") + } + if !translatedAddr.IsValid() { + return fmt.Errorf("invalid translated IP address") } if m.localipmanager.IsLocalIP(translatedAddr) { @@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr return nil } -// RemoveInternalDNATMapping removes a 1:1 IP address mapping +// RemoveInternalDNATMapping removes a 1:1 IP address mapping. func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { m.dnatMutex.Lock() defer m.dnatMutex.Unlock() @@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets +// translateOutboundDNAT applies DNAT translation to outbound packets. func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("Failed to rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic +// translateInboundReverse applies reverse DNAT to inbound return traffic. func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("Failed to rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// rewritePacketSource replaces the source IP address in the packet -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip return nil } +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 +// incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { return ^uint16(sum) } -// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errNatNotSupported @@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) return m.nativeFirewall.AddDNATRule(rule) } -// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +// DeleteDNATRule deletes outbound DNAT rule. func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { if m.nativeFirewall == nil { return errNatNotSupported } return m.nativeFirewall.DeleteDNATRule(rule) } + +// addPortRedirection adds a port redirection rule. +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + rule := portDNATRule{ + protocol: protocol, + origPort: sourcePort, + targetPort: targetPort, + targetIP: targetIP, + } + + m.portDNATRules = append(m.portDNATRules, rule) + m.portDNATEnabled.Store(true) + + return nil +} + +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// removePortRedirection removes a port redirection rule. +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) + + if len(m.portDNATRules) == 0 { + m.portDNATEnabled.Store(false) + } + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. +func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { + if !m.portDNATEnabled.Load() { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: + return false + } +} + +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error + +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { + m.portDNATMutex.RLock() + defer m.portDNATMutex.RUnlock() + + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue + } + + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { + return false + } + + if rule.origPort != port { + continue + } + + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) + return false + } + d.dnatOrigPort = rule.origPort + return true + } + return false +} + +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. +func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+4 { + return fmt.Errorf("packet too short for TCP header") + } + + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + if len(packetData) >= tcpStart+18 { + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + + return nil +} + +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. +func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") + } + + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + } + + return nil +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d726474cf 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 710abd445..2a285484c 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/device" ) @@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) { err = manager.RemoveInternalDNATMapping(originalIP) require.Error(t, err, "Should error when removing non-existent mapping") } + +func TestInboundPortDNAT(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) + + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") + + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) + } + + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") + + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } +} + +func TestInboundPortDNATNegative(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) + + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) + + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..ee1bb8a23 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a3a4ba40f..a1c0dff98 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -12,18 +14,14 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex -) - const ( - dnsTTL = 60 //seconds + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -36,28 +34,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + localAddr netip.Addr + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func SetListenPort(port uint16) { - listenPortMu.Lock() - listenPort = port - listenPortMu.Unlock() -} - -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + localAddr: localAddr, + serverPort: serverPort, } } @@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index bebf04f6c..8f75c0646 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,9 +202,6 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup - - // dns forwarder port - dnsFwdPort uint16 } // Peer is an instance of the Connection Peer @@ -247,7 +244,6 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), } sm := profilemanager.NewServiceManager("") @@ -1084,7 +1080,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1843,16 +1839,11 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } - if forwarderPort > 0 { - dnsfwd.SetListenPort(forwarderPort) - } - if !enabled { if e.dnsForwardMgr == nil { return @@ -1864,20 +1855,17 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) + if e.dnsForwardMgr == nil { + localAddr := e.wgInterface.Address().IP + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - default: + log.Infof("started domain router service with %d entries", len(fwdEntries)) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1887,20 +1875,6 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } - -} - -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..a8e697626 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -18,9 +18,9 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + pkgdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..40586f24d 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. + ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..e5166ce47 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 83caf74ef..96f73a390 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index fd795b926..3b2ab87fc 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ad82d37d9..3982ea2af 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -410,7 +410,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone