diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..6e57ceb89 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 32103b7ec..16b50211e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -260,7 +260,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -268,7 +268,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 343f5e05e..80aea7cf8 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -880,7 +880,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) @@ -913,7 +913,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index aa016e1c2..aa90d3b9a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -376,7 +376,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -384,7 +384,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0c091da96..648a6aedf 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1350,7 +1350,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) @@ -1426,7 +1426,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index e81042979..a480bbdbb 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -116,11 +116,9 @@ type Manager struct { dnatMutex sync.RWMutex dnatBiMap *biDNATMap - // Port-specific DNAT for SSH redirection portDNATEnabled atomic.Bool - portDNATMap *portDNATMap + portDNATRules []portDNATRule portDNATMutex sync.RWMutex - portNATTracker *portNATTracker netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex @@ -137,6 +135,8 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor @@ -211,8 +211,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), - portDNATMap: &portDNATMap{rules: make([]portDNATRule, 0)}, - portNATTracker: newPortNATTracker(), + portDNATRules: []portDNATRule{}, netstackServices: make(map[serviceKey]struct{}), } m.routingEnabled.Store(false) @@ -351,22 +350,18 @@ func (m *Manager) initForwarder() error { return nil } -// Init initializes the firewall manager with state manager. func (m *Manager) Init(*statemanager.Manager) error { return nil } -// IsServerRouteSupported returns whether server routes are supported. func (m *Manager) IsServerRouteSupported() bool { return true } -// IsStateful returns whether the firewall manager tracks connection state. func (m *Manager) IsStateful() bool { return m.stateful } -// AddNatRule adds a routing firewall rule for NAT translation. func (m *Manager) AddNatRule(pair firewall.RouterPair) error { if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.AddNatRule(pair) @@ -652,9 +647,8 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) - m.translateOutboundPortReverse(packetData, d) return false } @@ -697,14 +691,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -714,13 +720,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -782,10 +790,11 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } - if translated := m.translateInboundPortDNAT(packetData, d); translated { + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { // Re-decode after port DNAT translation to update port information if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after port DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) @@ -794,7 +803,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index bf1c6feb5..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,8 +5,7 @@ import ( "errors" "fmt" "net/netip" - "sync" - "time" + "slices" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -21,10 +20,16 @@ var ( ) const ( - errRewriteTCPDestinationPort = "rewrite TCP destination port: %v" + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 ) -// ipv4Checksum calculates IPv4 header checksum using optimized parallel processing for performance. +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -64,7 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } -// icmpChecksum calculates ICMP checksum using parallel accumulation for high-performance processing. +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -102,116 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } -// biDNATMap maintains bidirectional DNAT mappings for efficient forward and reverse lookups. +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } -// portDNATRule represents a port-specific DNAT rule +// portDNATRule represents a port-specific DNAT rule. type portDNATRule struct { protocol gopacket.LayerType - sourcePort uint16 + origPort uint16 targetPort uint16 targetIP netip.Addr } -// portDNATMap manages port-specific DNAT rules -type portDNATMap struct { - rules []portDNATRule -} - -// ConnKey represents a connection 4-tuple for NAT tracking. -type ConnKey struct { - SrcIP netip.Addr - DstIP netip.Addr - SrcPort uint16 - DstPort uint16 -} - -// portNATConn tracks port NAT state for a specific connection. -type portNATConn struct { - rule portDNATRule - originalPort uint16 - translatedAt time.Time -} - -// portNATTracker tracks connection-specific port NAT state -type portNATTracker struct { - connections map[ConnKey]*portNATConn - mutex sync.RWMutex -} - -// newPortNATTracker creates a new port NAT tracker for stateful connection tracking. -func newPortNATTracker() *portNATTracker { - return &portNATTracker{ - connections: make(map[ConnKey]*portNATConn), - } -} - -// trackConnection tracks a connection that has port NAT applied using translated port as key. -func (t *portNATTracker) trackConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, rule portDNATRule) { - t.mutex.Lock() - defer t.mutex.Unlock() - - key := ConnKey{ - SrcIP: srcIP, - DstIP: dstIP, - SrcPort: srcPort, - DstPort: rule.targetPort, - } - - t.connections[key] = &portNATConn{ - rule: rule, - originalPort: dstPort, - translatedAt: time.Now(), - } -} - -// getConnectionNAT returns NAT info for a connection if tracked, looking up by connection 4-tuple. -func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) (*portNATConn, bool) { - t.mutex.RLock() - defer t.mutex.RUnlock() - - key := ConnKey{ - SrcIP: srcIP, - DstIP: dstIP, - SrcPort: srcPort, - DstPort: dstPort, - } - - conn, exists := t.connections[key] - return conn, exists -} - -// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts. -func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool { - t.mutex.RLock() - defer t.mutex.RUnlock() - - for key, conn := range t.connections { - if key.SrcIP == dstIP && key.DstIP == srcIP && - conn.rule.sourcePort == dstPort && conn.originalPort == dstPort { - return false - } - } - return true -} - -// cleanupConnection removes a NAT connection based on original 4-tuple for connection cleanup. -func (t *portNATTracker) cleanupConnection(srcIP, dstIP netip.Addr, srcPort uint16) { - t.mutex.Lock() - defer t.mutex.Unlock() - - for key := range t.connections { - if key.SrcIP == srcIP && key.DstIP == dstIP && key.SrcPort == srcPort { - delete(t.connections, key) - return - } - } -} - -// newBiDNATMap creates a new bidirectional DNAT mapping structure for efficient forward/reverse lookups. +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -219,7 +129,7 @@ func newBiDNATMap() *biDNATMap { } } -// set adds a bidirectional DNAT mapping between original and translated addresses for both directions. +// set adds a bidirectional DNAT mapping between original and translated addresses. func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original @@ -233,13 +143,13 @@ func (b *biDNATMap) delete(original netip.Addr) { } } -// getTranslated returns the translated address for a given original address from forward mapping. +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } -// getOriginal returns the original address for a given translated address from reverse mapping. +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists @@ -261,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -295,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists with fast-path optimization. +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -307,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic using reverse mapping. +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -319,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets for 1:1 IP mapping. +// translateOutboundDNAT applies DNAT translation to outbound packets. func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -336,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -345,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic for 1:1 IP mapping. +// translateInboundReverse applies reverse DNAT to inbound return traffic. func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -362,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -371,17 +272,17 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet and updates checksums. -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. +func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { @@ -395,9 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -406,42 +307,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP return nil } -// rewritePacketSource replaces the source IP address in the packet and updates checksums. -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// updateTCPChecksum updates TCP checksum after IP address change using incremental update per RFC 1624. +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -454,7 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } -// updateUDPChecksum updates UDP checksum after IP address change using incremental update per RFC 1624. +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -472,7 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } -// updateICMPChecksum recalculates ICMP checksum after packet modification using full recalculation. +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -485,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 for performance. +// incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -536,25 +402,25 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.nativeFirewall.DeleteDNATRule(rule) } -// addPortRedirection adds port redirection rule for specified target IP, protocol and ports. +// addPortRedirection adds a port redirection rule. func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() rule := portDNATRule{ protocol: protocol, - sourcePort: sourcePort, + origPort: sourcePort, targetPort: targetPort, targetIP: targetIP, } - m.portDNATMap.rules = append(m.portDNATMap.rules, rule) + m.portDNATRules = append(m.portDNATRules, rule) m.portDNATEnabled.Store(true) return nil } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports. +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -569,27 +435,23 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) } -// removePortRedirection removes port redirection rule for specified target IP, protocol and ports. +// removePortRedirection removes a port redirection rule. func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() - var filteredRules []portDNATRule - for _, rule := range m.portDNATMap.rules { - if !(rule.protocol == protocol && rule.sourcePort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0) { - filteredRules = append(filteredRules, rule) - } - } - m.portDNATMap.rules = filteredRules + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) - if len(m.portDNATMap.rules) == 0 { + if len(m.portDNATRules) == 0 { m.portDNATEnabled.Store(false) } return nil } -// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports. +// RemoveInboundDNAT removes an inbound DNAT rule. func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -604,146 +466,55 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) } -// translateInboundPortDNAT applies stateful port-specific DNAT translation to inbound packets. -func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder) bool { +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. +func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { if !m.portDNATEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: return false } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return false - } - - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - srcPort := uint16(d.tcp.SrcPort) - dstPort := uint16(d.tcp.DstPort) - - if m.handleReturnTraffic(packetData, d, srcIP, dstIP, srcPort, dstPort) { - return true - } - - return m.handleNewConnection(packetData, d, srcIP, dstIP, srcPort, dstPort) } -// handleReturnTraffic processes return traffic for existing NAT connections. -func (m *Manager) handleReturnTraffic(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - if m.isTranslatedPortTraffic(srcIP, srcPort) { - return false - } +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error - if handled := m.handleExistingNATConnection(packetData, d, srcIP, dstIP, srcPort, dstPort); handled { - return true - } - - return m.handleForwardTrafficInExistingConnections(packetData, d, srcIP, dstIP, srcPort, dstPort) -} - -// isTranslatedPortTraffic checks if traffic is from a translated port that should be handled by outbound reverse. -func (m *Manager) isTranslatedPortTraffic(srcIP netip.Addr, srcPort uint16) bool { +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { m.portDNATMutex.RLock() defer m.portDNATMutex.RUnlock() - for _, rule := range m.portDNATMap.rules { - if rule.protocol == layers.LayerTypeTCP && rule.targetPort == srcPort && - rule.targetIP.Unmap().Compare(srcIP.Unmap()) == 0 { - return true + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue } - } - return false -} -// handleExistingNATConnection processes return traffic for existing NAT connections. -func (m *Manager) handleExistingNATConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists { - if err := m.rewriteTCPDestinationPort(packetData, d, natConn.originalPort); err != nil { - m.logger.Error1(errRewriteTCPDestinationPort, err) + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { return false } - m.logger.Trace4("Inbound Port DNAT (return): %s:%d -> %s:%d", dstIP, srcPort, dstIP, natConn.originalPort) - return true - } - return false -} -// handleForwardTrafficInExistingConnections processes forward traffic in existing connections. -func (m *Manager) handleForwardTrafficInExistingConnections(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - m.portDNATMutex.RLock() - defer m.portDNATMutex.RUnlock() - - for _, rule := range m.portDNATMap.rules { - if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort { - continue - } - if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 { + if rule.origPort != port { continue } - if _, exists := m.portNATTracker.getConnectionNAT(srcIP, dstIP, srcPort, rule.targetPort); !exists { - continue - } - - if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil { - m.logger.Error1(errRewriteTCPDestinationPort, err) + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) return false } + d.dnatOrigPort = rule.origPort return true } - return false } -// handleNewConnection processes new connections that match port DNAT rules. -func (m *Manager) handleNewConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - m.portDNATMutex.RLock() - defer m.portDNATMutex.RUnlock() - - for _, rule := range m.portDNATMap.rules { - if m.applyPortDNATRule(packetData, d, rule, srcIP, dstIP, srcPort, dstPort) { - return true - } - } - return false -} - -// applyPortDNATRule applies a specific port DNAT rule if conditions are met. -func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNATRule, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort { - return false - } - - if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 { - return false - } - - if !m.portNATTracker.shouldApplyNAT(srcIP, dstIP, dstPort) { - return false - } - - if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil { - m.logger.Error1(errRewriteTCPDestinationPort, err) - return false - } - - m.portNATTracker.trackConnection(srcIP, dstIP, srcPort, dstPort, rule) - m.logger.Trace8("Inbound Port DNAT (new): %s:%d -> %s:%d (tracked: %s:%d -> %s:%d)", dstIP, rule.sourcePort, dstIP, rule.targetPort, srcIP, srcPort, dstIP, rule.targetPort) - return true -} - -// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum. -func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return ErrIPv4Only - } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return fmt.Errorf("not a TCP packet") - } - +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. +func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { return errInvalidIPHeaderLength @@ -754,9 +525,9 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo return fmt.Errorf("packet too short for TCP header") } - oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4]) - - binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort) + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) if len(packetData) >= tcpStart+18 { checksumOffset := tcpStart + 16 @@ -773,75 +544,34 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo return nil } -// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum. -func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return ErrIPv4Only - } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return fmt.Errorf("not a TCP packet") - } - +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. +func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { return errInvalidIPHeaderLength } - tcpStart := ipHeaderLen - if len(packetData) < tcpStart+4 { - return fmt.Errorf("packet too short for TCP header") + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") } - oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2]) + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) - binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort) - - if len(packetData) >= tcpStart+18 { - checksumOffset := tcpStart + 16 + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) - var oldPortBytes, newPortBytes [2]byte - binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) - binary.BigEndian.PutUint16(newPortBytes[:], newPort) - - newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) - binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } } return nil } - -// translateOutboundPortReverse applies stateful reverse port DNAT to outbound return traffic for SSH redirection. -func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool { - if !m.portDNATEnabled.Load() { - return false - } - - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return false - } - - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - srcPort := uint16(d.tcp.SrcPort) - dstPort := uint16(d.tcp.DstPort) - - // For outbound reverse, we need to find the connection using the same key as when it was stored - // Connection was stored as: srcIP, dstIP, srcPort, translatedPort - // So for return traffic (srcIP=server, dstIP=client), we need: dstIP, srcIP, dstPort, srcPort - if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists { - if err := m.rewriteTCPSourcePort(packetData, d, natConn.rule.sourcePort); err != nil { - m.logger.Error1("rewrite TCP source port: %v", err) - return false - } - - return true - } - - return false -} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d726474cf 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 4c43077bc..2a285484c 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -1,11 +1,8 @@ package uspfilter import ( - "io" - "net" "net/netip" "testing" - "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -148,8 +145,7 @@ func TestDNATMappingManagement(t *testing.T) { require.Error(t, err, "Should error when removing non-existent mapping") } -// TestSSHPortRedirection tests SSH port redirection functionality -func TestSSHPortRedirection(t *testing.T) { +func TestInboundPortDNAT(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false, flowLogger) @@ -158,462 +154,48 @@ func TestSSHPortRedirection(t *testing.T) { require.NoError(t, manager.Close(nil)) }() - // Define NetBird network range - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") - // Add SSH port redirection rule - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } - // Verify port DNAT is enabled - require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled") - require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) - // Verify the rule configuration - rule := manager.portDNATMap.rules[0] - require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol) - require.Equal(t, uint16(22), rule.sourcePort) - require.Equal(t, uint16(22022), rule.targetPort) - require.Equal(t, peerIP, rule.targetIP) + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) - // Test inbound SSH packet (client -> peer:22, should redirect to peer:22022) - inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - originalInbound := make([]byte, len(inboundPacket)) - copy(originalInbound, inboundPacket) + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") - // Process inbound packet - translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket)) - require.True(t, translated, "Inbound SSH packet should be translated") - - // Verify destination port was changed from 22 to 22022 - d := parsePacket(t, inboundPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Destination port should be rewritten to 22022") - - // Verify destination IP remains unchanged - dstIPAfter := netip.AddrFrom4([4]byte{inboundPacket[16], inboundPacket[17], inboundPacket[18], inboundPacket[19]}) - require.Equal(t, peerIP, dstIPAfter, "Destination IP should remain unchanged") - - // Test outbound return packet (peer:22022 -> client, should rewrite source port to 22) - outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321) - originalOutbound := make([]byte, len(outboundPacket)) - copy(originalOutbound, outboundPacket) - - // Process outbound return packet - reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket)) - require.True(t, reversed, "Outbound return packet should be reverse translated") - - // Verify source port was changed from 22022 to 22 - d = parsePacket(t, outboundPacket) - require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Source port should be rewritten to 22") - - // Verify source IP remains unchanged - srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]}) - require.Equal(t, peerIP, srcIPAfter, "Source IP should remain unchanged") - - // Test removal of SSH port redirection - err = manager.RemoveInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal") - require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal") -} - -// TestSSHPortRedirectionNetworkFiltering tests that SSH redirection only applies to specified networks -func TestSSHPortRedirectionNetworkFiltering(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define NetBird network range - peerInNetwork := netip.MustParseAddr("100.10.0.50") - peerOutsideNetwork := netip.MustParseAddr("192.168.1.50") - clientIP := netip.MustParseAddr("100.10.0.100") - - // Add SSH port redirection rule for NetBird network only - err = manager.AddInboundDNAT(peerInNetwork, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Test SSH packet to peer within NetBird network (should be redirected) - inNetworkPacket := generateDNATTestPacket(t, clientIP, peerInNetwork, layers.IPProtocolTCP, 54321, 22) - translated := manager.translateInboundPortDNAT(inNetworkPacket, parsePacket(t, inNetworkPacket)) - require.True(t, translated, "SSH packet to NetBird peer should be translated") - - // Verify port was changed - d := parsePacket(t, inNetworkPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected for NetBird peer") - - // Test SSH packet to peer outside NetBird network (should NOT be redirected) - outOfNetworkPacket := generateDNATTestPacket(t, clientIP, peerOutsideNetwork, layers.IPProtocolTCP, 54321, 22) - originalOutOfNetwork := make([]byte, len(outOfNetworkPacket)) - copy(originalOutOfNetwork, outOfNetworkPacket) - - notTranslated := manager.translateInboundPortDNAT(outOfNetworkPacket, parsePacket(t, outOfNetworkPacket)) - require.False(t, notTranslated, "SSH packet to non-NetBird peer should NOT be translated") - - // Verify packet was not modified - require.Equal(t, originalOutOfNetwork, outOfNetworkPacket, "Packet to non-NetBird peer should remain unchanged") -} - -// TestSSHPortRedirectionNonTCPTraffic tests that only TCP traffic is affected -func TestSSHPortRedirectionNonTCPTraffic(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define NetBird network range - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") - - // Add SSH port redirection rule - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Test UDP packet on port 22 (should NOT be redirected) - udpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolUDP, 54321, 22) - originalUDP := make([]byte, len(udpPacket)) - copy(originalUDP, udpPacket) - - translated := manager.translateInboundPortDNAT(udpPacket, parsePacket(t, udpPacket)) - require.False(t, translated, "UDP packet should NOT be translated by SSH port redirection") - require.Equal(t, originalUDP, udpPacket, "UDP packet should remain unchanged") - - // Test ICMP packet (should NOT be redirected) - icmpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolICMPv4, 0, 0) - originalICMP := make([]byte, len(icmpPacket)) - copy(originalICMP, icmpPacket) - - translated = manager.translateInboundPortDNAT(icmpPacket, parsePacket(t, icmpPacket)) - require.False(t, translated, "ICMP packet should NOT be translated by SSH port redirection") - require.Equal(t, originalICMP, icmpPacket, "ICMP packet should remain unchanged") -} - -// TestSSHPortRedirectionNonSSHPorts tests that only port 22 is redirected -func TestSSHPortRedirectionNonSSHPorts(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define NetBird network range - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") - - // Add SSH port redirection rule - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Test TCP packet on port 80 (should NOT be redirected) - httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80) - originalHTTP := make([]byte, len(httpPacket)) - copy(originalHTTP, httpPacket) - - translated := manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket)) - require.False(t, translated, "Non-SSH TCP packet should NOT be translated") - require.Equal(t, originalHTTP, httpPacket, "Non-SSH TCP packet should remain unchanged") - - // Test TCP packet on port 443 (should NOT be redirected) - httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443) - originalHTTPS := make([]byte, len(httpsPacket)) - copy(originalHTTPS, httpsPacket) - - translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket)) - require.False(t, translated, "Non-SSH TCP packet should NOT be translated") - require.Equal(t, originalHTTPS, httpsPacket, "Non-SSH TCP packet should remain unchanged") - - // Test TCP packet on port 22 (SHOULD be redirected) - sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - translated = manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket)) - require.True(t, translated, "SSH TCP packet should be translated") - - // Verify port was changed to 22022 - d := parsePacket(t, sshPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH port should be redirected to 22022") -} - -// TestFlexiblePortRedirection tests the flexible port redirection functionality -func TestFlexiblePortRedirection(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define peer and client IPs - peerIP := netip.MustParseAddr("10.0.0.50") - clientIP := netip.MustParseAddr("10.0.0.100") - - // Add custom port redirection: TCP port 8080 -> 3000 for peer IP - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000) - require.NoError(t, err) - - // Verify port DNAT is enabled - require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled") - require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule") - - // Verify the rule configuration - rule := manager.portDNATMap.rules[0] - require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol) - require.Equal(t, uint16(8080), rule.sourcePort) - require.Equal(t, uint16(3000), rule.targetPort) - require.Equal(t, peerIP, rule.targetIP) - - // Test inbound packet (client -> peer:8080, should redirect to peer:3000) - inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 8080) - translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket)) - require.True(t, translated, "Inbound packet should be translated") - - // Verify destination port was changed from 8080 to 3000 - d := parsePacket(t, inboundPacket) - require.Equal(t, uint16(3000), uint16(d.tcp.DstPort), "Destination port should be rewritten to 3000") - - // Test outbound return packet (peer:3000 -> client, should rewrite source port to 8080) - outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 3000, 54321) - reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket)) - require.True(t, reversed, "Outbound return packet should be reverse translated") - - // Verify source port was changed from 3000 to 8080 - d = parsePacket(t, outboundPacket) - require.Equal(t, uint16(8080), uint16(d.tcp.SrcPort), "Source port should be rewritten to 8080") - - // Test removal of port redirection - err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000) - require.NoError(t, err) - require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal") - require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal") -} - -// TestMultiplePortRedirections tests multiple port redirection rules -func TestMultiplePortRedirections(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define peer and client IPs - peerIP := netip.MustParseAddr("172.16.0.50") - clientIP := netip.MustParseAddr("172.16.0.100") - - // Add multiple port redirections for peer IP - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 22, 22022) // SSH - require.NoError(t, err) - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) // HTTP - require.NoError(t, err) - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 443, 8443) // HTTPS - require.NoError(t, err) - - // Verify all rules are present - require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled") - require.Len(t, manager.portDNATMap.rules, 3, "Should have three port DNAT rules") - - // Test SSH redirection (22 -> 22022) - sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket)) - require.True(t, translated, "SSH packet should be translated") - d := parsePacket(t, sshPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH should redirect to 22022") - - // Test HTTP redirection (80 -> 8080) - httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80) - translated = manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket)) - require.True(t, translated, "HTTP packet should be translated") - d = parsePacket(t, httpPacket) - require.Equal(t, uint16(8080), uint16(d.tcp.DstPort), "HTTP should redirect to 8080") - - // Test HTTPS redirection (443 -> 8443) - httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443) - translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket)) - require.True(t, translated, "HTTPS packet should be translated") - d = parsePacket(t, httpsPacket) - require.Equal(t, uint16(8443), uint16(d.tcp.DstPort), "HTTPS should redirect to 8443") - - // Test removing one rule (HTTP) - err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) - require.NoError(t, err) - require.Len(t, manager.portDNATMap.rules, 2, "Should have two rules after removing HTTP rule") - - // Verify HTTP is no longer redirected - httpPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80) - originalHTTP := make([]byte, len(httpPacket2)) - copy(originalHTTP, httpPacket2) - translated = manager.translateInboundPortDNAT(httpPacket2, parsePacket(t, httpPacket2)) - require.False(t, translated, "HTTP packet should NOT be translated after rule removal") - require.Equal(t, originalHTTP, httpPacket2, "HTTP packet should remain unchanged") - - // Verify SSH and HTTPS still work - sshPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - translated = manager.translateInboundPortDNAT(sshPacket2, parsePacket(t, sshPacket2)) - require.True(t, translated, "SSH should still be translated") - - httpsPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443) - translated = manager.translateInboundPortDNAT(httpsPacket2, parsePacket(t, httpsPacket2)) - require.True(t, translated, "HTTPS should still be translated") -} - -// TestSSHPortRedirectionEndToEnd tests actual network delivery through sockets -func TestSSHPortRedirectionEndToEnd(t *testing.T) { - // Start a mock SSH server on port 22022 (NetBird SSH server) - mockSSHServer, err := net.Listen("tcp", "127.0.0.1:22022") - require.NoError(t, err, "Should be able to bind to NetBird SSH port") - defer func() { - require.NoError(t, mockSSHServer.Close()) - }() - - // Handle connections on the SSH server - serverReceivedData := make(chan string, 1) - go func() { - for { - conn, err := mockSSHServer.Accept() - if err != nil { - return // Server closed + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) } - go func(conn net.Conn) { - defer func() { - require.NoError(t, conn.Close()) - }() - buf := make([]byte, 1024) - n, err := conn.Read(buf) - if err != nil && err != io.EOF { - t.Logf("Server read error: %v", err) - return - } + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") - receivedData := string(buf[:n]) - serverReceivedData <- receivedData - - // Echo back a response - _, err = conn.Write([]byte("SSH-2.0-MockNetBirdSSH\r\n")) - if err != nil { - t.Logf("Server write error: %v", err) - } - }(conn) - } - }() - - // Give server time to start - time.Sleep(100 * time.Millisecond) - - // This test demonstrates what SHOULD happen after port redirection: - // 1. Client connects to 127.0.0.1:22 (standard SSH port) - // 2. Firewall redirects to 127.0.0.1:22022 (NetBird SSH server) - // 3. NetBird SSH server receives the connection - - t.Run("DirectConnectionToNetBirdSSHPort", func(t *testing.T) { - // This simulates what should happen AFTER port redirection - // Connect directly to 22022 (where NetBird SSH server listens) - conn, err := net.DialTimeout("tcp", "127.0.0.1:22022", 5*time.Second) - require.NoError(t, err, "Should connect to NetBird SSH server") - defer func() { - require.NoError(t, conn.Close()) - }() - - // Send SSH client identification - testData := "SSH-2.0-TestClient\r\n" - _, err = conn.Write([]byte(testData)) - require.NoError(t, err, "Should send data to SSH server") - - // Verify server received the data - select { - case received := <-serverReceivedData: - require.Equal(t, testData, received, "Server should receive client data") - case <-time.After(2 * time.Second): - t.Fatal("Server did not receive data within timeout") - } - - // Read server response - buf := make([]byte, 1024) - if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { - t.Logf("failed to set read deadline: %v", err) - } - n, err := conn.Read(buf) - require.NoError(t, err, "Should read server response") - - response := string(buf[:n]) - require.Equal(t, "SSH-2.0-MockNetBirdSSH\r\n", response, "Should receive SSH server identification") - }) - - t.Run("PortRedirectionSimulation", func(t *testing.T) { - // This test simulates the port redirection process - // Note: This doesn't test the actual userspace packet interception, - // but demonstrates the expected behavior - - t.Log("NOTE: This test demonstrates expected behavior after implementing") - t.Log("full userspace packet interception. Currently, we test packet") - t.Log("translation logic separately from actual network delivery.") - - // In a real implementation with userspace packet interception: - // 1. Client would connect to 127.0.0.1:22 - // 2. Userspace firewall would intercept packets - // 3. translateInboundPortDNAT would rewrite port 22 -> 22022 - // 4. Packets would be delivered to 127.0.0.1:22022 - // 5. NetBird SSH server would receive the connection - - // For now, we verify that the packet translation logic works correctly - // (this is tested in other test functions) and that the target server - // is reachable (tested above) - - clientIP := netip.MustParseAddr("127.0.0.1") - serverIP := netip.MustParseAddr("127.0.0.1") - - // Create manager with SSH port redirection - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Add SSH port redirection for localhost (for testing) - err = manager.AddInboundDNAT(netip.MustParseAddr("127.0.0.1"), firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Generate packet: client connecting to server:22 - sshPacket := generateDNATTestPacket(t, clientIP, serverIP, layers.IPProtocolTCP, 54321, 22) - originalPacket := make([]byte, len(sshPacket)) - copy(originalPacket, sshPacket) - - // Apply port redirection - translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket)) - require.True(t, translated, "SSH packet should be translated") - - // Verify port was redirected from 22 to 22022 - d := parsePacket(t, sshPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected to NetBird SSH server") - require.NotEqual(t, originalPacket, sshPacket, "Packet should be modified") - - t.Log("✓ Packet translation verified: port 22 redirected to 22022") - t.Log("✓ Target SSH server (port 22022) is reachable and responsive") - t.Log("→ Integration complete: SSH port redirection ready for userspace interception") - }) + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } } -// TestFullSSHRedirectionWorkflow demonstrates the complete SSH redirection workflow -func TestFullSSHRedirectionWorkflow(t *testing.T) { - t.Log("=== SSH Port Redirection Workflow Test ===") - t.Log("This test demonstrates the complete SSH redirection process:") - t.Log("1. Client connects to peer:22 (standard SSH)") - t.Log("2. Userspace firewall intercepts and redirects to peer:22022") - t.Log("3. NetBird SSH server receives connection on port 22022") - t.Log("4. Return traffic is reverse-translated (22022 -> 22)") - - // Setup test environment +func TestInboundPortDNATNegative(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, }, false, flowLogger) @@ -622,47 +204,51 @@ func TestFullSSHRedirectionWorkflow(t *testing.T) { require.NoError(t, manager.Close(nil)) }() - // Define NetBird network and peer IPs - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") - // Step 1: Configure SSH port redirection - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) require.NoError(t, err) - t.Log("✓ SSH port redirection configured for NetBird network") - // Step 2: Simulate inbound SSH connection (client -> peer:22) - t.Log("→ Simulating: ssh user@100.10.0.50") - inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } - // Step 3: Apply inbound port redirection - translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket)) - require.True(t, translated, "Inbound SSH packet should be redirected") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) - d := parsePacket(t, inboundPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Should redirect to NetBird SSH server port") - t.Log("✓ Inbound packet redirected: 100.10.0.50:22 → 100.10.0.50:22022") + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) - // Step 4: Simulate outbound return traffic (peer:22022 -> client) - t.Log("→ Simulating return traffic from NetBird SSH server") - outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321) - - // Step 5: Apply outbound reverse translation - reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket)) - require.True(t, reversed, "Outbound return packet should be reverse translated") - - d = parsePacket(t, outboundPacket) - require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Should restore original SSH port") - t.Log("✓ Outbound packet reverse translated: 100.10.0.50:22022 → 100.10.0.50:22") - - // Step 6: Verify client sees standard SSH connection - srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]}) - require.Equal(t, peerIP, srcIPAfter, "Client should see traffic from peer IP") - t.Log("✓ Client receives traffic from 100.10.0.50:22 (transparent redirection)") - - t.Log("=== SSH Port Redirection Workflow Complete ===") - t.Log("Result: Standard SSH clients can connect to NetBird peers using:") - t.Log(" ssh user@100.10.0.50") - t.Log("Instead of:") - t.Log(" ssh user@100.10.0.50 -p 22022") + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..ee1bb8a23 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +362,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index d39910cb4..c3d006a2b 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -576,6 +576,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) @@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a3a4ba40f..a1c0dff98 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -12,18 +14,14 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex -) - const ( - dnsTTL = 60 //seconds + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -36,28 +34,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + localAddr netip.Addr + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func SetListenPort(port uint16) { - listenPortMu.Lock() - listenPort = port - listenPortMu.Unlock() -} - -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + localAddr: localAddr, + serverPort: serverPort, } } @@ -71,7 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -96,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -111,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index ed47c0cee..90af070ae 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -205,8 +205,7 @@ type Engine struct { wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup - // dns forwarder port - dnsFwdPort uint16 + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -248,7 +247,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } sm := profilemanager.NewServiceManager("") @@ -1038,7 +1037,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1613,7 +1612,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. -func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1645,8 +1644,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1663,13 +1666,6 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - // restartEngine restarts the engine by cancelling the client context func (e *Engine) restartEngine() { e.syncMsgMux.Lock() @@ -1789,16 +1785,11 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } - if forwarderPort > 0 { - dnsfwd.SetListenPort(forwarderPort) - } - if !enabled { if e.dnsForwardMgr == nil { return @@ -1810,20 +1801,17 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) + if e.dnsForwardMgr == nil { + localAddr := e.wgInterface.Address().IP + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - default: + log.Infof("started domain router service with %d entries", len(fwdEntries)) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1833,20 +1821,6 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } - -} - -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..693ea1f31 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,8 +33,164 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe is return fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age < p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) + log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) + } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) @@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } // ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) + for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..a8e697626 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -18,9 +18,9 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" + pkgdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..d590dba0d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -106,7 +106,7 @@ type DefaultManager struct { func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} diff --git a/client/server/server.go b/client/server/server.go index b3fdb2f1f..76f44cd13 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1070,10 +1070,7 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() @@ -1266,7 +1263,7 @@ func isUnixRunningDesktop() bool { return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" } -func (s *Server) runProbes() { +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1277,7 +1274,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 339d02bac..98fdc7c5b 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -167,35 +167,35 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { } expectedFields := map[string]bool{ - "ManagementUrl": true, - "AdminURL": true, - "RosenpassEnabled": true, - "RosenpassPermissive": true, - "ServerSSHAllowed": true, - "InterfaceName": true, - "WireguardPort": true, - "OptionalPreSharedKey": true, - "DisableAutoConnect": true, - "NetworkMonitor": true, - "DisableClientRoutes": true, - "DisableServerRoutes": true, - "DisableDns": true, - "DisableFirewall": true, - "BlockLanAccess": true, - "DisableNotifications": true, - "LazyConnectionEnabled": true, - "BlockInbound": true, - "NatExternalIPs": true, - "CustomDNSAddress": true, - "ExtraIFaceBlacklist": true, - "DnsLabels": true, - "DnsRouteInterval": true, - "Mtu": true, - "EnableSSHRoot": true, - "EnableSSHSFTP": true, - "EnableSSHLocalPortForward": true, - "EnableSSHRemotePortForward": true, - "DisableSSHAuth": true, + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + "EnableSSHRoot": true, + "EnableSSHSFTP": true, + "EnableSSHLocalPortForward": true, + "EnableSSHRemotePortForward": true, + "DisableSSHAuth": true, } val := reflect.ValueOf(req).Elem() @@ -226,34 +226,34 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) { // Map of CLI flag names to their corresponding SetConfigRequest field names. // This map must be updated when adding new config-related CLI flags. flagToField := map[string]string{ - "management-url": "ManagementUrl", - "admin-url": "AdminURL", - "enable-rosenpass": "RosenpassEnabled", - "rosenpass-permissive": "RosenpassPermissive", - "allow-server-ssh": "ServerSSHAllowed", - "interface-name": "InterfaceName", - "wireguard-port": "WireguardPort", - "preshared-key": "OptionalPreSharedKey", - "disable-auto-connect": "DisableAutoConnect", - "network-monitor": "NetworkMonitor", - "disable-client-routes": "DisableClientRoutes", - "disable-server-routes": "DisableServerRoutes", - "disable-dns": "DisableDns", - "disable-firewall": "DisableFirewall", - "block-lan-access": "BlockLanAccess", - "block-inbound": "BlockInbound", - "enable-lazy-connection": "LazyConnectionEnabled", - "external-ip-map": "NatExternalIPs", - "dns-resolver-address": "CustomDNSAddress", - "extra-iface-blacklist": "ExtraIFaceBlacklist", - "extra-dns-labels": "DnsLabels", - "dns-router-interval": "DnsRouteInterval", - "mtu": "Mtu", - "enable-ssh-root": "EnableSSHRoot", - "enable-ssh-sftp": "EnableSSHSFTP", - "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForward", - "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForward", - "disable-ssh-auth": "DisableSSHAuth", + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + "enable-ssh-root": "EnableSSHRoot", + "enable-ssh-sftp": "EnableSSHSFTP", + "enable-ssh-local-port-forwarding": "EnableSSHLocalPortForward", + "enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForward", + "disable-ssh-auth": "DisableSSHAuth", } // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/status/status.go b/client/status/status.go index 5e4fcd8dc..8a0b7bae0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." + } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..40586f24d 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. + ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" diff --git a/go.mod b/go.mod index 1a3e500e8..17851b44a 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 1bc3c24ca..21a225d62 100644 --- a/go.sum +++ b/go.sum @@ -508,8 +508,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index ab9893f27..1b61c081d 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index a1ed9498b..fe9fb25c6 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..e5166ce47 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 7e65b8f92..6b2db1162 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 387de43e5..804b4a73f 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 251c04273..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index be05c2527..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -15,6 +15,7 @@ type IntegratedValidator interface { PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d160e7269..e87043f26 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 5a134eb32..b8c03a4bd 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..50bdc6ab3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 11107e5de..16737cf58 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -428,7 +428,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {