diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..40b55f858 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" "runtime" 
"strings" @@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := openBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } +// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func openBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} + // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..6e57ceb89 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { 
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..d78372c9e 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 81f7a9125..16b50211e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -260,6 +260,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. 
+func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..80aea7cf8 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -880,6 +880,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + dnatRule := []string{ + "-i", r.wgIface.Name(), + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + ruleInfo := ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + r.rules[ruleID] = ruleInfo.rule + + r.updateState() + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. 
+func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + return fmt.Errorf("delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3b3164823..7ee33118b 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -151,14 +151,20 @@ type Manager interface { DisableRouting() error - // AddDNATRule adds a DNAT rule + // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. AddDNATRule(ForwardRule) (Rule, error) - // DeleteDNATRule deletes a DNAT rule + // DeleteDNATRule deletes the outbound DNAT rule. 
DeleteDNATRule(Rule) error // UpdateSet updates the set with the given prefixes UpdateSet(hash Set, prefixes []netip.Prefix) error + + // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveInboundDNAT removes inbound DNAT rule + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 560f224f5..aa90d3b9a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -376,6 +376,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. 
+func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..648a6aedf 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -1350,6 +1350,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 3, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 3, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) 
+ + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + RegProtoMax: 0, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if rule, exists := r.rules[ruleID]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go 
b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) 
TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := 
t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. 
Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, 
dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + 
conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7eef49e31..fbc39b740 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -109,6 +109,10 @@ type Manager struct { dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex dnatBiMap *biDNATMap + + portDNATEnabled atomic.Bool + portDNATRules []portDNATRule + portDNATMutex sync.RWMutex } // decoder for packages @@ -122,6 +126,8 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor @@ -196,6 +202,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), + portDNATRules: []portDNATRule{}, } m.routingEnabled.Store(false) @@ -630,7 +637,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return true } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) return false @@ -674,14 +681,26 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, 
size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -691,13 +710,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -759,10 +780,20 @@ func (m *Manager) 
filterInbound(packetData []byte, size int) bool { return false } + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { + // Re-decode after port DNAT translation to update port information + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 5614e2ec3..139f702f2 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -50,6 +50,8 @@ type logMessage struct { arg4 any arg5 any arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } - func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { } } +func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { @@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 
any) { } } +// Trace8 logs a trace message with 8 arguments (8 placeholder in format string) +func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + default: + } + } +} + func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") @@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { argCount++ if msg.arg6 != nil { argCount++ + if msg.arg7 != nil { + argCount++ + if msg.arg8 != nil { + argCount++ + } + } } } } @@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) case 6: formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + case 7: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7) + case 8: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8) } *buf = append(*buf, formatted...) 
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} \ No newline at end of file +} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 27b752531..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/netip" + "slices" + "github.com/google/gopacket" "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -13,6 +15,21 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + +const ( + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 +) + +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } +// portDNATRule represents a port-specific DNAT rule. +type portDNATRule struct { + protocol gopacket.LayerType + origPort uint16 + targetPort uint16 + targetIP netip.Addr +} + +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap { } } +// set adds a bidirectional DNAT mapping between original and translated addresses. 
func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original } +// delete removes a bidirectional DNAT mapping for the given original address. func (b *biDNATMap) delete(original netip.Addr) { if translated, exists := b.forward[original]; exists { delete(b.forward, original) @@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) { } } +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists } +// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation. func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { - if !originalAddr.IsValid() || !translatedAddr.IsValid() { - return fmt.Errorf("invalid IP addresses") + if !originalAddr.IsValid() { + return fmt.Errorf("invalid original IP address") + } + if !translatedAddr.IsValid() { + return fmt.Errorf("invalid translated IP address") } if m.localipmanager.IsLocalIP(translatedAddr) { @@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr return nil } -// RemoveInternalDNATMapping removes a 1:1 IP address mapping +// RemoveInternalDNATMapping removes a 1:1 IP address mapping. 
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { m.dnatMutex.Lock() defer m.dnatMutex.Unlock() @@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets +// translateOutboundDNAT applies DNAT translation to outbound packets. 
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("Failed to rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic +// translateInboundReverse applies reverse DNAT to inbound return traffic. 
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("Failed to rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. 
+func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// rewritePacketSource replaces the source IP address in the packet -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, 
oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip return nil } +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 +// incrementalUpdate performs incremental checksum update per RFC 1624. 
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { return ^uint16(sum) } -// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errNatNotSupported @@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) return m.nativeFirewall.AddDNATRule(rule) } -// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +// DeleteDNATRule deletes outbound DNAT rule. func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { if m.nativeFirewall == nil { return errNatNotSupported } return m.nativeFirewall.DeleteDNATRule(rule) } + +// addPortRedirection adds a port redirection rule. +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + rule := portDNATRule{ + protocol: protocol, + origPort: sourcePort, + targetPort: targetPort, + targetIP: targetIP, + } + + m.portDNATRules = append(m.portDNATRules, rule) + m.portDNATEnabled.Store(true) + + return nil +} + +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
+func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// removePortRedirection removes a port redirection rule. +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) + + if len(m.portDNATRules) == 0 { + m.portDNATEnabled.Store(false) + } + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. 
+func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { + if !m.portDNATEnabled.Load() { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: + return false + } +} + +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error + +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { + m.portDNATMutex.RLock() + defer m.portDNATMutex.RUnlock() + + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue + } + + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { + return false + } + + if rule.origPort != port { + continue + } + + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) + return false + } + d.dnatOrigPort = rule.origPort + return true + } + return false +} + +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. 
+func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+4 { + return fmt.Errorf("packet too short for TCP header") + } + + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + if len(packetData) >= tcpStart+18 { + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + + return nil +} + +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. 
+func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") + } + + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + } + + return nil +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d726474cf 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -414,3 +414,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", 
+ }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + 
manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 710abd445..2a285484c 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/device" ) @@ -143,3 +144,111 @@ func TestDNATMappingManagement(t *testing.T) { err = manager.RemoveInternalDNATMapping(originalIP) require.Error(t, err, "Should error when removing non-existent mapping") } + +func TestInboundPortDNAT(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } + + for _, tc := range 
testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) + + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") + + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) + } + + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") + + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } +} + +func TestInboundPortDNATNegative(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, 
tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) + + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) + + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: 
"Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func 
(m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + 
translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], 
packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..ee1bb8a23 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -104,6 +104,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, 
StagePeerACL, @@ -126,6 +128,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +157,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +185,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +212,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +238,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +258,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +278,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +303,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +319,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +339,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +362,8 
@@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +386,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +408,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +434,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..7763f2417 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" "runtime" "time" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -17,6 +20,9 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) +// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready +var ErrConnectionShutdown = errors.New("connection shutdown before ready") + // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// waitForConnectionReady blocks until the connection becomes ready or fails. +// Returns an error if the connection times out, is cancelled, or enters shutdown state. 
+func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { + conn.Connect() + + state := conn.GetState() + for state != connectivity.Ready && state != connectivity.Shutdown { + if !conn.WaitForStateChange(ctx, state) { + return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) + } + state = conn.GetState() + } + + if state == connectivity.Shutdown { + return ErrConnectionShutdown + } + + return nil +} + // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.NewClient( addr, transportOption, WithCustomDialer(tlsEnabled, component), - grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - log.Printf("DialContext error: %v", err) + return nil, fmt.Errorf("new client: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForConnectionReady(ctx, conn); err != nil { + _ = conn.Close() return nil, err } diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if 
runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 84afc38f5..96b286175 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + // Get the existing peer to preserve its allowed IPs + existingPeer, err := c.getPeer(c.deviceName, peerKey) + if err != nil { + return fmt.Errorf("get peer: %w", err) + } + + removePeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil { + return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err) + } + + // Re-add the peer without the endpoint but same AllowedIPs + reAddPeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + AllowedIPs: existingPeer.AllowedIPs, + ReplaceAllowedIPs: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil { + return fmt.Errorf( + `error re-adding peer %s to interface %s with allowed IPs %v: %w`, + peerKey, c.deviceName, existingPeer.AllowedIPs, err, + ) + } + + return nil +} + func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index
f744e0127..bc875b73c 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return fmt.Errorf("parse peer key: %w", err) + } + + ipcStr, err := c.device.IpcGet() + if err != nil { + return fmt.Errorf("get IPC config: %w", err) + } + + // Parse current status to get allowed IPs for the peer + stats, err := parseStatus(c.deviceName, ipcStr) + if err != nil { + return fmt.Errorf("parse IPC config: %w", err) + } + + var allowedIPs []net.IPNet + found := false + for _, peer := range stats.Peers { + if peer.PublicKey == peerKey { + allowedIPs = peer.AllowedIPs + found = true + break + } + } + if !found { + return fmt.Errorf("peer %s not found", peerKey) + } + + // remove the peer from the WireGuard configuration + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil { + return fmt.Errorf("failed to remove peer: %s", ipcErr) + } + + // Build the peer config + peer = wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + } + + config = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil { + return fmt.Errorf("remove endpoint address: %w", err) + } + + return nil +} + func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..c0c829825 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() 
*device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..48346fc0f 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..acd5f6f11 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..f96edf992 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..2a836f846 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff 
--git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b68..40d8fdac8 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..24654fc03 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..96350df8a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 000000000..b53888baa --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. 
+// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 1f40b0d46..db53d9c3a 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -21,4 +21,5 @@ type WGConfigurer interface { GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) LastActivities() map[string]monotime.Time + RemoveEndpointAddress(peerKey string) error } diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..cdfcea48d 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 609572561..07235a995 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager for userspace bind mode.
+func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind @@ -148,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } +func (w *WGIface) RemoveEndpointAddress(peerKey string) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } + + log.Debugf("Removing endpoint address: %s", peerKey) + return w.configurer.RemoveEndpointAddress(peerKey) +} + // RemovePeer removes a Wireguard Peer from the interface iface func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..965decc73 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -29,11 +29,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) + rules := networkMap.FirewallRules enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - // if 
TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). + // If SSH enabled, add default firewall rule which accepts connection to any peer + // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", @@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. 
- addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. 
- continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. 
- for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..daf4979ce 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: 
mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: 
mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: 
"10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: 
mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port 
restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: 
mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..442f54e71 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. 
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index e4b4c2368..8233ca510 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -14,6 +14,9 @@ type WGIface interface { } func (g *BundleGenerator) addWgShow() error { + if g.statusRecorder == nil { + return fmt.Errorf("no status recorder available for wg show") + } result, err := g.statusRecorder.PeersStatus() if err != nil { return err diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS 
config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get 
system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: 
"example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..01b7edc48 100644 
--- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager 
*statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) @@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. 
+func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := &registryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; 
i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := &registryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0e8a53a63..d9854c033 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -31,6 +31,7 @@ const ( systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS" systemdDbusResolvConfModeForeign = "foreign" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" @@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana log.Warnf("failed to set DNSSEC to 'no': %v", err) } + // We don't support DNSOverTLS.
On some machines this is default on so we explicitly set it to off + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil { + log.Warnf("failed to set DNSOverTLS to 'no': %v", err) + } + var ( searchDomains []string matchDomains []string diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go new file mode 100644 index 000000000..43fe2d020 --- /dev/null +++ b/client/internal/dnsfwd/cache.go @@ -0,0 +1,78 @@ +package dnsfwd + +import ( + "net/netip" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" +) + +type cache struct { + mu sync.RWMutex + records map[string]*cacheEntry +} + +type cacheEntry struct { + ip4Addrs []netip.Addr + ip6Addrs []netip.Addr +} + +func newCache() *cache { + return &cache{ + records: make(map[string]*cacheEntry), + } +} + +func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.records[normalizeDomain(domain)] + if !exists { + return nil, false + } + + switch reqType { + case dns.TypeA: + return slices.Clone(entry.ip4Addrs), true + case dns.TypeAAAA: + return slices.Clone(entry.ip6Addrs), true + default: + return nil, false + } + +} + +func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { + 
c.mu.Lock() + defer c.mu.Unlock() + norm := normalizeDomain(domain) + entry, exists := c.records[norm] + if !exists { + entry = &cacheEntry{} + c.records[norm] = entry + } + + switch reqType { + case dns.TypeA: + entry.ip4Addrs = slices.Clone(addrs) + case dns.TypeAAAA: + entry.ip6Addrs = slices.Clone(addrs) + } +} + +// unset removes all cached entries for the given domain. +func (c *cache) unset(domain string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.records, normalizeDomain(domain)) +} + +// normalizeDomain converts an input domain into a canonical form used as cache key: +// lowercase and fully-qualified (with trailing dot). +func normalizeDomain(domain string) string { + // dns.Fqdn ensures trailing dot; ToLower for consistent casing + return dns.Fqdn(strings.ToLower(domain)) +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go new file mode 100644 index 000000000..c23f0f31d --- /dev/null +++ b/client/internal/dnsfwd/cache_test.go @@ -0,0 +1,86 @@ +package dnsfwd + +import ( + "net/netip" + "testing" +) + +func mustAddr(t *testing.T, s string) netip.Addr { + t.Helper() + a, err := netip.ParseAddr(s) + if err != nil { + t.Fatalf("parse addr %s: %v", s, err) + } + return a +} + +func TestCacheNormalization(t *testing.T) { + c := newCache() + + // Mixed case, without trailing dot + domainInput := "ExAmPlE.CoM" + ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")} + c.set(domainInput, 1 /* dns.TypeA */, ipv4) + + // Lookup with lower, with trailing dot + if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok) + } + + // Lookup with different casing again + if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok) + } +} + +func TestCacheSeparateTypes(t *testing.T) { + c 
:= newCache() + + domain := "test.local" + ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")} + ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")} + + c.set(domain, 1 /* A */, ipv4) + c.set(domain, 28 /* AAAA */, ipv6) + + got4, ok4 := c.get(domain, 1) + if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] { + t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4) + } + + got6, ok6 := c.get(domain, 28) + if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] { + t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6) + } +} + +func TestCacheCloneOnGetAndSet(t *testing.T) { + c := newCache() + domain := "clone.test" + + src := []netip.Addr{mustAddr(t, "8.8.8.8")} + c.set(domain, 1, src) + + // Mutate source slice; cache should be unaffected + src[0] = mustAddr(t, "9.9.9.9") + + got, ok := c.get(domain, 1) + if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" { + t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok) + } + + // Mutate returned slice; internal cache should remain unchanged + got[0] = mustAddr(t, "4.4.4.4") + got2, ok2 := c.get(domain, 1) + if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" { + t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2) + } +} + +func TestCacheMiss(t *testing.T) { + c := newCache() + if got, ok := c.get("missing.example", 1); ok || got != nil { + t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) + } +} + diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index d912919a1..7a262fa4c 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -46,6 +46,7 @@ type DNSForwarder struct { fwdEntries []*ForwarderEntry firewall firewaller resolver resolver + cache *cache } func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { @@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat 
firewall: firewall, statusRecorder: statusRecorder, resolver: net.DefaultResolver, + cache: newCache(), } } @@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() + // remove cache entries for domains that no longer appear + f.removeStaleCacheEntries(f.fwdEntries, entries) + f.fwdEntries = entries log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } +// removeStaleCacheEntries unsets cache items for domains that were present +// in the old list but not present in the new list. +func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) { + if f.cache == nil { + return + } + + newSet := make(map[string]struct{}, len(newEntries)) + for _, e := range newEntries { + if e == nil { + continue + } + newSet[e.Domain.PunycodeString()] = struct{}{} + } + + for _, e := range oldEntries { + if e == nil { + continue + } + pattern := e.Domain.PunycodeString() + if _, ok := newSet[pattern]; !ok { + f.cache.unset(pattern) + } + } +} + func (f *DNSForwarder) Close(ctx context.Context) error { var result *multierror.Error @@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) + f.cache.set(domain, question.Qtype, ips) return resp } @@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns resp.Rcode = dns.RcodeSuccess } -// handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { +// handleDNSError processes DNS lookup errors and sends an appropriate error response. 
+func (f *DNSForwarder) handleDNSError( + ctx context.Context, + w dns.ResponseWriter, + question dns.Question, + resp *dns.Msg, + domain string, + err error, +) { + // Default to SERVFAIL; override below when appropriate. + resp.Rcode = dns.RcodeServerFailure + + qType := question.Qtype + qTypeName := dns.TypeToString[qType] + + // Prefer typed DNS errors; fall back to generic logging otherwise. var dnsErr *net.DNSError - - switch { - case errors.As(err, &dnsErr): - resp.Rcode = dns.RcodeServerFailure - if dnsErr.IsNotFound { - f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype) + if !errors.As(err, &dnsErr) { + log.Warnf(errResolveFailed, domain, err) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } + return + } - if dnsErr.Server != "" { - log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) - } else { - log.Warnf(errResolveFailed, domain, err) + // NotFound: set NXDOMAIN / appropriate code via helper. + if dnsErr.IsNotFound { + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } - default: - resp.Rcode = dns.RcodeServerFailure + f.cache.set(domain, question.Qtype, nil) + return + } + + // Upstream failed but we might have a cached answer—serve it if present. 
+ if ips, ok := f.cache.get(domain, qType); ok { + if len(ips) > 0 { + log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + f.addIPsToResponse(resp, domain, ips) + resp.Rcode = dns.RcodeSuccess + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write cached DNS response: %v", writeErr) + } + } else { // send NXDOMAIN / appropriate code if cache is empty + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) + } + } + return + } + + // No cache. Log with or without the server field for more context. + if dnsErr.Server != "" { + log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) + } else { log.Warnf(errResolveFailed, domain, err) } - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write failure DNS response: %v", err) + // Write final failure response. + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 57085e19a..c1c95a2c1 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") } +// Ensures that when the first query succeeds and populates the cache, +// a subsequent upstream failure still returns a successful response from cache. 
+func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("example.com") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("1.2.3.4") + + // First call resolves successfully and populates cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{ip}, nil).Once() + + // Second call fails upstream; forwarder should serve from cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + // First query: populate cache + q1 := &dns.Msg{} + q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Second query: serve from cache after upstream failure + q2 := &dns.Msg{} + q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + +// Verifies that cache normalization works across casing and trailing dot variations. 
+func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("ExAmPlE.CoM") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("9.8.7.6") + + // Initial resolution with mixed case to populate cache + mixedQuery := "ExAmPlE.CoM" + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))). + Return([]netip.Addr{ip}, nil).Once() + + q1 := &dns.Msg{} + q1.SetQuestion(mixedQuery+".", dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Subsequent query without dot and upper case should hit cache even if upstream fails + // Forwarder lowercases and uses the question name as-is (no trailing dot here) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")). 
+ Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + q2 := &dns.Msg{} + q2.SetQuestion("EXAMPLE.COM", dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp) + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { // Test complex overlapping pattern scenarios mockFirewall := &MockFirewall{} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..a1c0dff98 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,7 +4,9 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -12,18 +14,14 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex -) - const ( - dnsTTL = 60 //seconds + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. 
@@ -36,24 +34,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + localAddr netip.Addr + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, + localAddr: localAddr, + serverPort: serverPort, } } @@ -67,13 +71,21 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort) + } } - m.dnsForwarder = 
NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -98,6 +110,17 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + if m.localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -113,7 +136,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 646e059d4..19d37eee1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -203,8 +203,7 @@ type Engine struct { wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup - // dns forwarder port - dnsFwdPort uint16 + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -247,7 +246,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } sm := profilemanager.NewServiceManager("") @@ -1060,10 +1059,14 @@ func (e *Engine) updateNetworkMap(networkMap 
*mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + + if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) } + e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1084,7 +1087,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1208,10 +1211,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) + if forwarderPort == 0 { + forwarderPort = nbdns.ForwarderClientPort + } + dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0), + ForwarderPort: forwarderPort, } for _, zone := range protoDNSConfig.GetCustomZones() { @@ -1667,7 +1676,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. 
-func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1699,8 +1708,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1717,13 +1730,6 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - // restartEngine restarts the engine by cancelling the client context func (e *Engine) restartEngine() { e.syncMsgMux.Lock() @@ -1843,7 +1849,6 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return @@ -1860,20 +1865,17 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + if e.dnsForwardMgr == nil { + localAddr := e.wgInterface.Address().IP + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - default: + 
log.Infof("started domain router service with %d entries", len(fwdEntries)) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1883,20 +1885,6 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } - -} - -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 344104405..2f1098100 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -105,6 +105,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RemoveEndpointAddress(_ string) error { + return nil +} + func (m *MockWGIface) FullStats() (*configurer.Stats, error) { return nil, fmt.Errorf("not implemented") } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 690fdb7cc..98fe01912 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -28,6 +28,7 @@ type wgIfaceBase interface { UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemoveEndpointAddress(key string) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) 
error diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 000000000..2564a9905 --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. 
+func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 000000000..792d04215 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. 
+// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. 
+func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 000000000..f86dd3877 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { 
+ return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + 
listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := 
mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := 
mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d2..000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git 
a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c3..e0b09be6c 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. 
+func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 000000000..d2adb9bf4 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func 
TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) 
+ + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb8..db283ec9a 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. 
+type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. 
+ if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4..0768d9219 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil - } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) IsUserspaceBind() bool { + return false +} + +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - 
listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7..0626c1815 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( 
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time } diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8db9e58f4..68afe986a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) - 
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) + conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) if !isForceRelayed() { - conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) + conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) @@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() { } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } changed := conn.statusICE.Get() != worker.StatusDisconnected @@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() { if conn.currentConnPriority == conntype.Relay { conn.Log.Debugf("clean up WireGuard config") conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } if conn.wgProxyRelay != nil { diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index c839ab147..6b47f95eb 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteOffer(OfferAnswer{ @@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case 
<-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") @@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteAnswer(OfferAnswer{ @@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") diff --git a/client/internal/peer/guard/env.go b/client/internal/peer/guard/env.go new file mode 100644 index 000000000..1ea2d21be --- /dev/null +++ b/client/internal/peer/guard/env.go @@ -0,0 +1,20 @@ +package guard + +import ( + "os" + "strconv" + "time" +) + +const ( + envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD" +) + +func GetICEMonitorPeriod() time.Duration { + if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" { + if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return defaultCandidatesMonitorPeriod +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 09cf9ae63..0f22ee7b0 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -16,8 +16,8 @@ import ( ) const ( - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second + defaultCandidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second ) type ICEMonitor struct { @@ -25,16 +25,19 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config + tickerPeriod time.Duration currentCandidatesAddress []string candidatesMu 
sync.Mutex } -func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor { + log.Debugf("prepare ICE monitor with period: %s", period) cm := &ICEMonitor{ ReconnectCh: make(chan struct{}, 1), iFaceDiscover: iFaceDiscover, iceConfig: config, + tickerPeriod: period, } return cm } @@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { return } - ticker := time.NewTicker(candidatesMonitorPeriod) + // Initial check to populate the candidates for later comparison + if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil { + log.Warnf("Failed to check initial ICE candidates: %v", err) + } + + ticker := time.NewTicker(cm.tickerPeriod) defer ticker.Stop() for { diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 90e45426f..686430752 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -51,7 +51,7 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) go iceMonitor.Start(ctx, w.onICEChanged) w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 42eaea683..aff26f847 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -44,13 +44,19 @@ type OfferAnswer struct { } type Handshaker struct { - mu sync.Mutex - log *log.Entry - config ConnConfig - signaler *Signaler - ice *WorkerICE - relay *WorkerRelay - onNewOfferListeners []*OfferListener + mu sync.Mutex + log *log.Entry + config ConnConfig 
+ signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + // relayListener is not blocking because the listener is using a goroutine to process the messages + // and it will only keep the latest message if multiple offers are received in a short time + // this is to avoid blocking the handshaker if the listener is doing some heavy processing + // and also to avoid processing old offers if multiple offers are received in a short time + // the listener will always process the latest offer + relayListener *AsyncOfferListener + iceListener func(remoteOfferAnswer *OfferAnswer) // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } } -func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - l := NewOfferListener(offer) - h.onNewOfferListeners = append(h.onNewOfferListeners, l) +func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.relayListener = NewAsyncOfferListener(offer) +} + +func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.iceListener = offer } func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - // received confirmation from the remote peer -> ready to proceed + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) + } + if err := h.sendAnswer(); err != nil { h.log.Errorf("failed to send remote offer confirmation: %s", err) continue } - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) - } - 
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go index e2d3f3f38..772e2777f 100644 --- a/client/internal/peer/handshaker_listener.go +++ b/client/internal/peer/handshaker_listener.go @@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string { return oa.SessionID.String() } -type OfferListener struct { +type AsyncOfferListener struct { fn callbackFunc running bool latest *OfferAnswer mu sync.Mutex } -func NewOfferListener(fn callbackFunc) *OfferListener { - return &OfferListener{ +func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener { + return &AsyncOfferListener{ fn: fn, } } -func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { +func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) { o.mu.Lock() defer o.mu.Unlock() diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go index 8363741a5..1a7156d10 100644 --- a/client/internal/peer/handshaker_listener_test.go +++ b/client/internal/peer/handshaker_listener_test.go @@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) { runChan <- struct{}{} } - hl := NewOfferListener(longRunningFn) + hl := 
NewAsyncOfferListener(longRunningFn) hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer) diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go index 0bcc7a68e..678396e61 100644 --- a/client/internal/peer/iface.go +++ b/client/internal/peer/iface.go @@ -18,4 +18,5 @@ type WGIface interface { GetStats() (map[string]configurer.WGStats, error) GetProxy() wgproxy.Proxy Address() wgaddr.Address + RemoveEndpointAddress(key string) error } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index eb886a4d3..3675f0157 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn * func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.muxAgent.Lock() + defer w.muxAgent.Unlock() - if w.agentConnecting { - w.log.Debugf("agent connection is in progress, skipping the offer") - w.muxAgent.Unlock() - return - } - - if w.agent != nil { + if w.agent != nil || w.agentConnecting { // backward compatibility with old clients that do not send session ID if remoteOfferAnswer.SessionID == nil { w.log.Debugf("agent already exists, skipping the offer") - w.muxAgent.Unlock() return } if w.remoteSessionID == *remoteOfferAnswer.SessionID { w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) - w.muxAgent.Unlock() return } w.log.Debugf("agent already exists, recreate the connection") @@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } + + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil } @@ -126,18 +125,23 @@ func (w 
*WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { preferredCandidateTypes = icemaker.CandidateTypes() } - w.log.Debugf("recreate ICE agent") + if remoteOfferAnswer.SessionID != nil { + w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID) + } dialerCtx, dialerCancel := context.WithCancel(w.ctx) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) if err != nil { w.log.Errorf("failed to recreate ICE Agent: %s", err) - w.muxAgent.Unlock() return } w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true - w.muxAgent.Unlock() + if remoteOfferAnswer.SessionID != nil { + w.remoteSessionID = *remoteOfferAnswer.SessionID + } else { + w.remoteSessionID = "" + } go w.connect(dialerCtx, agent, remoteOfferAnswer) } @@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.muxAgent.Lock() w.agentConnecting = false w.lastSuccess = time.Now() - if remoteOfferAnswer.SessionID != nil { - w.remoteSessionID = *remoteOfferAnswer.SessionID - } w.muxAgent.Unlock() // todo: the potential problem is a race between the onConnectionStateChange @@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C } w.muxAgent.Lock() - // todo review does it make sense to generate new session ID all the time when w.agent==agent - sessionID, err := NewICESessionID() - if err != nil { - w.log.Errorf("failed to create new session ID: %s", err) - } - w.sessionID = sessionID if w.agent == agent { + // consider to remove from here and move to the OnNewOffer + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil w.agentConnecting = false + w.remoteSessionID = "" } w.muxAgent.Unlock() } @@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia // ice.ConnectionStateClosed happens when we recreate the 
agent. For the P2P to TURN switch important to // notify the conn.onICEStateDisconnected changes to update the current used priority + w.closeAgent(agent, dialerCancel) + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() } - w.closeAgent(agent, dialerCancel) default: return } diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..f03822089 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + WgPort: iface.DefaultWgPort, } if _, err := config.apply(input); err != nil { diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 45e37bf0e..90bde7707 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -5,11 +5,14 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) @@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) { } } +func TestNewProfileDefaults(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + }) + require.NoError(t, err, "should create new config") + + assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default") + assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default") + assert.NotEmpty(t, config.PrivateKey, "PrivateKey 
should be generated") + assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated") + assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default") + assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820") + assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default") + assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default") + assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set") + assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set") + assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults") + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS") + assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS") + } +} + +func TestWireguardPortZeroExplicit(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // Create a new profile with explicit port 0 (random port) + explicitZero := 0 + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: &explicitZero, + }) + require.NoError(t, err, "should create config with explicit port 0") + + assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user") + + // Verify it persists + readConfig, err := GetConfig(configPath) + require.NoError(t, err) + assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file") +} + +func TestWireguardPortDefaultVsExplicit(t *testing.T) { + tests := []struct { + name string + wireguardPort *int + expectedPort int + description string + }{ + { + name: "no port specified uses default", + wireguardPort: nil, + expectedPort: iface.DefaultWgPort, + description: "When user doesn't specify port, default to 51820", + }, + { + name: 
"explicit zero for random port", + wireguardPort: func() *int { v := 0; return &v }(), + expectedPort: 0, + description: "When user explicitly sets 0, use 0 for random port", + }, + { + name: "explicit custom port", + wireguardPort: func() *int { v := 52000; return &v }(), + expectedPort: 52000, + description: "When user sets custom port, use that port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: tt.wireguardPort, + }) + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedPort, config.WgPort, tt.description) + }) + } +} + func TestUpdateOldManagementURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..693ea1f31 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,8 +33,164 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := 
generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe is return fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age < p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) 
+ log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) + } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) @@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } // 
ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) 
+ for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go index def18411f..8b5407850 100644 --- a/client/internal/routemanager/common/params.go +++ b/client/internal/routemanager/common/params.go @@ -1,6 +1,7 @@ package common import ( + "sync/atomic" "time" "github.com/netbirdio/netbird/client/firewall/manager" @@ -25,4 +26,5 @@ type HandlerParams struct { UseNewDNSRoute bool Firewall manager.Manager FakeIPManager *fakeip.Manager + ForwarderPort *atomic.Uint32 } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..348338dac 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -18,7 +19,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common" @@ -55,6 +55,7 @@ type DnsInterceptor struct { peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager + forwarderPort *atomic.Uint32 } func New(params common.HandlerParams) *DnsInterceptor { @@ -69,6 +70,7 @@ func 
New(params common.HandlerParams) *DnsInterceptor { firewall: params.Firewall, fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), + forwarderPort: params.ForwarderPort, } } @@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..37974cd17 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -10,6 +10,7 @@ import ( "runtime" "slices" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -54,6 +56,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string SetFirewall(firewall.Manager) error + SetDNSForwarderPort(port uint16) Stop(stateManager *statemanager.Manager) } @@ -101,12 +104,13 @@ type DefaultManager struct { disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler fakeIPManager *fakeip.Manager + dnsForwarderPort atomic.Uint32 } func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, 
notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) @@ -130,6 +134,7 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } + dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort)) useNoop := netstack.IsEnabled() || config.DisableClientRoutes dm.setupRefCounters(useNoop) @@ -270,6 +275,11 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { return nil } +// SetDNSForwarderPort sets the DNS forwarder port for route handlers +func (m *DefaultManager) SetDNSForwarderPort(port uint16) { + m.dnsForwarderPort.Store(uint32(port)) +} + // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() @@ -345,6 +355,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { UseNewDNSRoute: m.useNewDNSRoute, Firewall: m.firewall, FakeIPManager: m.fakeIPManager, + ForwarderPort: &m.dnsForwarderPort, } handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index be633c3fa..6b06144b2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } +// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface +func (m *MockManager) SetDNSForwarderPort(port uint16) { +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null 
+++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ 
b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go 
b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. 
+func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) 
(map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. 
+func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. 
+func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. 
func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/client/server/server.go b/client/server/server.go index e6de608c5..3641e6f92 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.CustomDNSAddress = []byte{} } + config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + + if msg.DnsRouteInterval != nil { + interval := msg.DnsRouteInterval.AsDuration() + config.DNSRouteInterval = &interval + } + config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassPermissive = msg.RosenpassPermissive config.DisableAutoConnect = msg.DisableAutoConnect @@ -1050,10 +1057,7 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() @@ -1063,7 +1067,7 @@ func (s *Server) Status( return &statusResponse, nil } -func (s *Server) runProbes() { +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1074,7 +1078,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } diff --git 
a/client/server/setconfig_test.go b/client/server/setconfig_test.go new file mode 100644 index 000000000..1260bcc78 --- /dev/null +++ b/client/server/setconfig_test.go @@ -0,0 +1,298 @@ +package server + +import ( + "context" + "os/user" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config. +// This test uses reflection to detect when new fields are added but not handled in SetConfig. +func TestSetConfig_AllFieldsSaved(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.ConfigDirOverride = tempDir + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + + currUser, err := user.Current() + require.NoError(t, err) + + profName := "test-profile" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: "https://api.netbird.io:443", + } + _, err = profilemanager.UpdateOrCreateConfig(ic) + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + require.NoError(t, err) + + ctx := 
context.Background() + s := New(ctx, "console", "", false, false) + + rosenpassEnabled := true + rosenpassPermissive := true + serverSSHAllowed := true + interfaceName := "utun100" + wireguardPort := int64(51820) + preSharedKey := "test-psk" + disableAutoConnect := true + networkMonitor := true + disableClientRoutes := true + disableServerRoutes := true + disableDNS := true + disableFirewall := true + blockLANAccess := true + disableNotifications := true + lazyConnectionEnabled := true + blockInbound := true + mtu := int64(1280) + + req := &proto.SetConfigRequest{ + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + LazyConnectionEnabled: &lazyConnectionEnabled, + BlockInbound: &blockInbound, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + } + + _, err = s.SetConfig(ctx, req) + require.NoError(t, err) + + profState := profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + } + cfgPath, err := profState.FilePath() + require.NoError(t, err) + + cfg, err := profilemanager.GetConfig(cfgPath) + require.NoError(t, err) + + require.Equal(t, 
"https://new-api.netbird.io:443", cfg.ManagementURL.String()) + require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String()) + require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled) + require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive) + require.NotNil(t, cfg.ServerSSHAllowed) + require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed) + require.Equal(t, interfaceName, cfg.WgIface) + require.Equal(t, int(wireguardPort), cfg.WgPort) + require.Equal(t, preSharedKey, cfg.PreSharedKey) + require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect) + require.NotNil(t, cfg.NetworkMonitor) + require.Equal(t, networkMonitor, *cfg.NetworkMonitor) + require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes) + require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes) + require.Equal(t, disableDNS, cfg.DisableDNS) + require.Equal(t, disableFirewall, cfg.DisableFirewall) + require.Equal(t, blockLANAccess, cfg.BlockLANAccess) + require.NotNil(t, cfg.DisableNotifications) + require.Equal(t, disableNotifications, *cfg.DisableNotifications) + require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) + require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) + require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) + // IFaceBlackList contains defaults + extras + require.Contains(t, cfg.IFaceBlackList, "eth1") + require.Contains(t, cfg.IFaceBlackList, "eth2") + require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) + require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) + require.Equal(t, uint16(mtu), cfg.MTU) + + verifyAllFieldsCovered(t, req) +} + +// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest. +// If a new field is added to SetConfigRequest, this function will fail the test, +// forcing the developer to update both the SetConfig handler and this test. 
+func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { + t.Helper() + + metadataFields := map[string]bool{ + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + } + + expectedFields := map[string]bool{ + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + } + + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unexpectedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + if metadataFields[fieldName] { + continue + } + + if !expectedFields[fieldName] { + unexpectedFields = append(unexpectedFields, fieldName) + } + } + + if len(unexpectedFields) > 0 { + t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields) + } +} + +// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest. +// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq. +func TestCLIFlags_MappedToSetConfig(t *testing.T) { + // Map of CLI flag names to their corresponding SetConfigRequest field names. 
+ // This map must be updated when adding new config-related CLI flags. + flagToField := map[string]string{ + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + } + + // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). + fieldsWithoutCLIFlags := map[string]bool{ + "DisableNotifications": true, // Only settable via UI + } + + // Get all SetConfigRequest fields to verify our map is complete. + req := &proto.SetConfigRequest{} + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unmappedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip protobuf internal fields and metadata fields. + if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" { + continue + } + if fieldName == "Username" || fieldName == "ProfileName" { + continue + } + if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" { + continue + } + + // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag. 
+ mappedToCLI := false + for _, mappedField := range flagToField { + if mappedField == fieldName { + mappedToCLI = true + break + } + } + + hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName] + + if !mappedToCLI && !hasNoCLIFlag { + unmappedFields = append(unmappedFields, fieldName) + } + } + + if len(unmappedFields) > 0 { + t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+ + "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+ + "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields) + } + + t.Log("All SetConfigRequest fields are properly documented") +} diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..8a0b7bae0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" 
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -205,15 +206,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { @@ -337,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." 
+ } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 25d7380a9..f4350b251 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,7 +31,6 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -297,6 +296,8 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + + connectCancel context.CancelFunc } type menuHandler struct { @@ -593,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } } -func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { +func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return nil, err + return nil, fmt.Errorf("get daemon client: %w", err) } activeProf, err := s.profileManager.GetActiveProfile() if err != nil { - log.Errorf("get active profile: %v", err) - return nil, err + return nil, fmt.Errorf("get active profile: %w", err) } currUser, err := user.Current() @@ -611,84 +610,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + loginResp, err := conn.Login(ctx, &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, }) if err != nil { - log.Errorf("login to management URL with: %v", err) - return nil, err + return nil, fmt.Errorf("login to management: 
%w", err) } if loginResp.NeedsSSOLogin && openURL { - err = s.handleSSOLogin(loginResp, conn) - if err != nil { - log.Errorf("handle SSO login failed: %v", err) - return nil, err + if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil { + return nil, fmt.Errorf("SSO login: %w", err) } } return loginResp, nil } -func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) - if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return err +func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + if err := openURL(loginResp.VerificationURIComplete); err != nil { + return fmt.Errorf("open browser: %w", err) } - resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { - log.Errorf("waiting sso login failed with: %v", err) - return err + return fmt.Errorf("wait for SSO login: %w", err) } if resp.Email != "" { - err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ Email: resp.Email, - }) - if err != nil { - log.Warnf("failed to set profile state: %v", err) + }); err != nil { + log.Debugf("failed to set profile state: %v", err) } else { s.mProfile.refresh() } - } return nil } -func (s *serviceClient) menuUpClick() error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { systray.SetTemplateIcon(iconErrorMacOS, s.icError) - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } - _, err = s.login(true) + _, err = 
s.login(ctx, true) if err != nil { - log.Errorf("login failed with: %v", err) - return err + return fmt.Errorf("login: %w", err) } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status == string(internal.StatusConnected) { - log.Warnf("already connected") return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - log.Errorf("up service: %v", err) - return err + if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("start connection: %w", err) } return nil @@ -698,24 +684,20 @@ func (s *serviceClient) menuDownClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { - log.Warnf("already down") return nil } - if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("down service: %v", err) - return err + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("stop connection: %w", err) } return nil @@ -851,6 +833,7 @@ func (s *serviceClient) onTrayReady() { newProfileMenuArgs := &newProfileMenuArgs{ ctx: s.ctx, + serviceClient: s, profileManager: s.profileManager, eventHandler: s.eventHandler, profileMenuItem: profileMenuItem, @@ -1354,7 +1337,13 @@ func (s *serviceClient) updateConfig() error { } // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. 
-func (s *serviceClient) showLoginURL() { +// It also starts a background goroutine that periodically checks if the client is already connected +// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is +// also cancelled when the window is closed. +func (s *serviceClient) showLoginURL() context.CancelFunc { + + // create a cancellable context for the background check goroutine + ctx, cancel := context.WithCancel(s.ctx) resIcon := fyne.NewStaticResource("netbird.png", iconAbout) @@ -1363,6 +1352,8 @@ func (s *serviceClient) showLoginURL() { s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.SetIcon(resIcon) } + // ensure goroutine is cancelled when the window is closed + s.wLoginURL.SetOnClosed(func() { cancel() }) // add a description label label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") @@ -1374,7 +1365,7 @@ func (s *serviceClient) showLoginURL() { return } - resp, err := s.login(false) + resp, err := s.login(ctx, false) if err != nil { log.Errorf("failed to fetch login URL: %v", err) return @@ -1394,7 +1385,7 @@ func (s *serviceClient) showLoginURL() { return } - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) if err != nil { log.Errorf("Waiting sso login failed with: %v", err) label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") @@ -1402,7 +1393,7 @@ func (s *serviceClient) showLoginURL() { } label.SetText("Re-authentication successful.\nReconnecting") - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) return @@ -1415,7 +1406,7 @@ func (s *serviceClient) showLoginURL() { return } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) + _, err = conn.Up(ctx, 
&proto.UpRequest{}) if err != nil { label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") log.Errorf("Reconnecting failed with: %v", err) @@ -1443,10 +1434,46 @@ func (s *serviceClient) showLoginURL() { ) s.wLoginURL.SetContent(container.NewCenter(content)) + // start a goroutine to check connection status and close the window if connected + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + continue + } + if status.Status == string(internal.StatusConnected) { + if s.wLoginURL != nil { + s.wLoginURL.Close() + } + return + } + } + } + }() + s.wLoginURL.Show() + + // return cancel func so callers can stop the background goroutine if desired + return cancel } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to 
get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git a/client/ui/event_handler.go 
b/client/ui/event_handler.go index e9b7f4f30..e0b619411 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" @@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) { func (h *eventHandler) handleConnectClick() { h.client.mUp.Disable() + + if h.client.connectCancel != nil { + h.client.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(h.client.ctx) + h.client.connectCancel = connectCancel + go func() { - defer h.client.mUp.Enable() - if err := h.client.menuUpClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) + defer connectCancel() + + if err := h.client.menuUpClick(connectCtx); err != nil { + st, ok := status.FromError(err) + if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { + log.Debugf("connect operation cancelled by user") + } else { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + log.Errorf("connect failed: %v", err) + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after connect: %v", err) } }() } func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + + if h.client.connectCancel != nil { + log.Debugf("cancelling ongoing connect operation") + h.client.connectCancel() + h.client.connectCancel = nil + } + go func() { - defer h.client.mDown.Enable() if err := h.client.menuDownClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) + st, ok := status.FromError(err) + if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { + 
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + log.Errorf("disconnect failed: %v", err) + } else { + log.Debugf("disconnect cancelled or already disconnecting") + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after disconnect: %v", err) } }() } @@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error { } h.client.getSrvConfig() - + return nil } diff --git a/client/ui/profile.go b/client/ui/profile.go index 075223795..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -387,6 +387,7 @@ type subItem struct { type profileMenu struct { mu sync.Mutex ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem @@ -396,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -404,12 +405,13 @@ type profileMenu struct { type newProfileMenuArgs struct { ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -418,6 +420,7 @@ type newProfileMenuArgs struct { func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ ctx: args.ctx, + serviceClient: args.serviceClient, profileManager: args.profileManager, eventHandler: args.eventHandler, profileMenuItem: args.profileMenuItem, @@ -569,10 +572,19 @@ func (p 
*profileMenu) refresh() { } } - if err := p.upClickCallback(); err != nil { + if p.serviceClient.connectCancel != nil { + p.serviceClient.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(p.ctx) + p.serviceClient.connectCancel = connectCancel + + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } + connectCancel() + p.refresh() p.loadSettingsCallback() } diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 4a23a4bc8..1678c3996 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert } } -func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { - return &tls.Config{ +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config { + config := &tls.Config{ InsecureSkipVerify: true, // We'll validate manually after handshake VerifyConnection: func(cs tls.ConnectionState) error { var certChain [][]byte @@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl return nil }, } + + // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3 + if requiresCredSSP { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS12 + } else { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS13 + } + + return config } diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 8062a05cc..16bf63bb9 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -6,11 +6,13 @@ import ( "context" "crypto/tls" "encoding/asn1" + "errors" "fmt" "io" "net" "sync" "syscall/js" + "time" log "github.com/sirupsen/logrus" ) @@ -19,18 +21,34 @@ 
const ( RDCleanPathVersion = 3390 RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyScheme = "ws" + + rdpDialTimeout = 15 * time.Second + + GeneralErrorCode = 1 + WSAETimedOut = 10060 + WSAEConnRefused = 10061 + WSAEConnAborted = 10053 + WSAEConnReset = 10054 + WSAEGenericError = 10050 ) type RDCleanPathPDU struct { - Version int64 `asn1:"tag:0,explicit"` - Error []byte `asn1:"tag:1,explicit,optional"` - Destination string `asn1:"utf8,tag:2,explicit,optional"` - ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` - ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` - PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` - X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` - ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` - ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` + Version int64 `asn1:"tag:0,explicit"` + Error RDCleanPathErr `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathErr struct { + ErrorCode int16 `asn1:"tag:0,explicit"` + HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"` + WSALastError int16 `asn1:"tag:2,explicit,optional"` + TLSAlertCode int8 `asn1:"tag:3,explicit,optional"` } type RDCleanPathProxy struct { @@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] destination := conn.destination log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := 
p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } conn.rdpConn = rdpConn @@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] _, err = rdpConn.Write(firstPacket) if err != nil { log.Errorf("Failed to write first packet: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] n, err := rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { conn.wsHandlers.Call("send", uint8Array.Get("buffer")) } } + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + p.sendToWebSocket(conn, data) +} + +func errorToWSACode(err error) int16 { + if err == nil { + return WSAEGenericError + } + var netErr *net.OpError + if errors.As(err, &netErr) && netErr.Timeout() { + return WSAETimedOut + } + if errors.Is(err, context.DeadlineExceeded) { + return WSAETimedOut + } + if errors.Is(err, context.Canceled) { + return WSAEConnAborted + } + if errors.Is(err, io.EOF) { + return WSAEConnReset + } + return WSAEGenericError +} + +func newWSAError(err error) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + WSALastError: errorToWSACode(err), + }, + } +} + +func newHTTPError(statusCode int16) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + HTTPStatusCode: statusCode, + }, + } +} diff --git 
a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go index 010efa5ea..97bb46338 100644 --- a/client/wasm/internal/rdp/rdcleanpath_handlers.go +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -3,6 +3,7 @@ package rdp import ( + "context" "crypto/tls" "encoding/asn1" "io" @@ -11,11 +12,17 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP) + protocolSSL = 0x00000001 + protocolHybridEx = 0x00000008 +) + func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) if pdu.Version != RDCleanPathVersion { - p.sendRDCleanPathError(conn, "Unsupported version") + p.sendRDCleanPathError(conn, newHTTPError(400)) return } @@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl destination = pdu.Destination } - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) - p.sendRDCleanPathError(conn, "Connection failed") + p.sendRDCleanPathError(conn, newWSAError(err)) p.cleanupConnection(conn) return } @@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl p.setupTLSConnection(conn, pdu) } +// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required. +// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags. +// Returns (requiresTLS12, selectedProtocol, detectionSuccessful). 
+func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) { + const minResponseLength = 19 + + if len(x224Response) < minResponseLength { + return false, 0, false + } + + // Per X.224 specification: + // x224Response[0] == 0x03: Length of X.224 header (3 bytes) + // x224Response[5] == 0xD0: X.224 Data TPDU code + if x224Response[0] != 0x03 || x224Response[5] != 0xD0 { + return false, 0, false + } + + if x224Response[11] == 0x02 { + flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 | + uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24 + + hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0 + return hasNLA, flags, true + } + + return false, 0, false +} + func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { var x224Response []byte if len(pdu.X224ConnectionPDU) > 0 { @@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) if err != nil { log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean n, err := conn.rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") + p.sendRDCleanPathError(conn, newWSAError(err)) return } x224Response = response[:n] log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) } - tlsConfig := p.getTLSConfigWithValidation(conn) + requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response) + if detected { + if requiresCredSSP { + log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol) + } else { + log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), 
allowing up to TLS 1.3", selectedProtocol) + } + } else { + log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3") + } + + tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP) tlsConn := tls.Client(conn.rdpConn, tlsConfig) conn.tlsConn = tlsConn if err := tlsConn.Handshake(); err != nil { log.Errorf("TLS handshake failed: %v", err) - p.sendRDCleanPathError(conn, "TLS handshake failed") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean p.cleanupConnection(conn) } -func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { - if len(pdu.X224ConnectionPDU) > 0 { - log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) - _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) - if err != nil { - log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") - return - } - - response := make([]byte, 1024) - n, err := conn.rdpConn.Read(response) - if err != nil { - log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") - return - } - - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - X224ConnectionPDU: response[:n], - ServerAddr: conn.destination, - } - - p.sendRDCleanPathPDU(conn, responsePDU) - } else { - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - ServerAddr: conn.destination, - } - p.sendRDCleanPathPDU(conn, responsePDU) - } - - go p.forwardConnToWS(conn, conn.rdpConn, "TCP") - go p.forwardWSToConn(conn, conn.rdpConn, "TCP") - - <-conn.ctx.Done() - log.Debug("TCP connection context done, cleaning up") - p.cleanupConnection(conn) -} - func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { data, err := asn1.Marshal(pdu) if err != nil { @@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) 
sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean p.sendToWebSocket(conn, data) } -func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { - pdu := RDCleanPathPDU{ - Version: RDCleanPathVersion, - Error: []byte(errorMsg), - } - - data, err := asn1.Marshal(pdu) - if err != nil { - log.Errorf("Failed to marshal error PDU: %v", err) - return - } - - p.sendToWebSocket(conn, data) -} - func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { msgChan := make(chan []byte) errChan := make(chan error) diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..cf089d4ed 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. 
+ ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" @@ -31,6 +35,8 @@ type Config struct { NameServerGroups []*NameServerGroup // CustomZones contains a list of custom zone CustomZones []CustomZone + // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding + ForwarderPort uint16 } // CustomZone represents a custom zone to be resolved by the dns server diff --git a/go.mod b/go.mod index a1560b409..06dc921f1 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 13838b82d..ce68ed99e 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= 
+github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + 
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 196e26a66..0010974c5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -49,6 +49,7 @@ services: - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.routers.netbird-signal.service=netbird-signal - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..209a20065 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + 
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 0c493f07d..db377865a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/dns.go b/management/server/dns.go index 5ba0d2620..decc5175d 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -199,7 +199,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. 
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -214,17 +214,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 83caf74ef..96f73a390 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, 
&cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, 
} result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", 
err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := 
h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..bdf56db6e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 
66c16870b..bc352f117 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -26,9 +26,11 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := io.ReadAll(req.Body) - if err == nil { - c.reqBody = string(body) + if req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } } return &http.Response{ StatusCode: c.code, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 51f99b3b7..f06e57196 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr APIToken: config.ExtraConfig["ApiToken"], } return NewJumpCloudManager(jumpcloudConfig, appMetrics) + case "pocketid": + pocketidConfig := PocketIdClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], + } + return NewPocketIdManager(pocketidConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go new file mode 100644 index 000000000..38a5cc67f --- /dev/null +++ b/management/server/idp/pocketid.go @@ -0,0 +1,384 @@ +package idp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type PocketIdManager struct { + managementEndpoint string + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +type pocketIdCustomClaimDto struct { + Key string `json:"key"` + Value string `json:"value"` +} + +type pocketIdUserDto struct { + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + Disabled bool `json:"disabled"` + DisplayName string 
`json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + ID string `json:"id"` + IsAdmin bool `json:"isAdmin"` + LastName string `json:"lastName"` + LdapID string `json:"ldapId"` + Locale string `json:"locale"` + UserGroups []pocketIdUserGroupDto `json:"userGroups"` + Username string `json:"username"` +} + +type pocketIdUserCreateDto struct { + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + IsAdmin bool `json:"isAdmin,omitempty"` + LastName string `json:"lastName,omitempty"` + Locale string `json:"locale,omitempty"` + Username string `json:"username"` +} + +type pocketIdPaginatedUserDto struct { + Data []pocketIdUserDto `json:"data"` + Pagination pocketIdPaginationDto `json:"pagination"` +} + +type pocketIdPaginationDto struct { + CurrentPage int `json:"currentPage"` + ItemsPerPage int `json:"itemsPerPage"` + TotalItems int `json:"totalItems"` + TotalPages int `json:"totalPages"` +} + +func (p *pocketIdUserDto) userData() *UserData { + return &UserData{ + Email: p.Email, + Name: p.DisplayName, + ID: p.ID, + AppMetadata: AppMetadata{}, + } +} + +type pocketIdUserGroupDto struct { + CreatedAt string `json:"createdAt"` + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + FriendlyName string `json:"friendlyName"` + ID string `json:"id"` + LdapID string `json:"ldapId"` + Name string `json:"name"` +} + +func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.ManagementEndpoint == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing") + } + + if config.APIToken == "" { + return nil, 
fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing") + } + + credentials := &PocketIdCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &PocketIdManager{ + managementEndpoint: config.ManagementEndpoint, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { + var MethodsWithBody = []string{http.MethodPost, http.MethodPut} + if !slices.Contains(MethodsWithBody, method) && body != "" { + return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + } + + reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) + if query != nil { + reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode()) + } + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body)) + } else { + req, err = http.NewRequestWithContext(ctx, method, reqURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Add("X-API-KEY", p.apiToken) + + if body != "" { + req.Header.Add("content-type", "application/json") + req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength)) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// getAllUsersPaginated fetches all users from PocketID API using pagination +func (p *PocketIdManager) 
getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) { + var allUsers []pocketIdUserDto + currentPage := 1 + + for { + params := url.Values{} + // Copy existing search parameters + for key, values := range searchParams { + params[key] = values + } + + params.Set("pagination[limit]", "100") + params.Set("pagination[page]", fmt.Sprintf("%d", currentPage)) + + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + var profiles pocketIdPaginatedUserDto + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + allUsers = append(allUsers, profiles.Data...) + + // Check if we've reached the last page + if currentPage >= profiles.Pagination.TotalPages { + break + } + + currentPage++ + } + + return allUsers, nil +} + +func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { + body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var user pocketIdUserDto + err = p.helper.Unmarshal(body, &user) + if err != nil { + return nil, err + } + + userData := user.userData() + userData.AppMetadata = appMetadata + + return userData, nil +} + +func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, profile := range allUsers { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountId + + users = append(users, userData) + } + 
+ return users, nil +} + +func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range allUsers { + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.Split(name, " ") + + createUser := pocketIdUserCreateDto{ + Disabled: false, + DisplayName: name, + Email: email, + FirstName: firstLast[0], + LastName: firstLast[1], + Username: firstLast[0] + "." + firstLast[1], + } + payload, err := p.helper.Marshal(createUser) + if err != nil { + return nil, err + } + + body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload)) + if err != nil { + return nil, err + } + var newUser pocketIdUserDto + err = p.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountCreateUser() + } + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil +} + +func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + params := url.Values{ + // "search" filters the user list by the given term (here: email) — TODO confirm against PocketID API docs + "search": []string{email}, + } + body, err := p.request(ctx, http.MethodGet, "users", &params, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + var profiles struct{ data []pocketIdUserDto } + 
err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.data { + users = append(users, profile.userData()) + } + return users, nil +} + +func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "") + if err != nil { + return err + } + return nil +} + +func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "") + if err != nil { + return err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +var _ Manager = (*PocketIdManager)(nil) + +type PocketIdClientConfig struct { + APIToken string + ManagementEndpoint string +} + +type PocketIdCredentials struct { + clientConfig PocketIdClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +var _ ManagerCredentials = (*PocketIdCredentials)(nil) + +func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) { + return JWTToken{}, nil +} diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go new file mode 100644 index 000000000..49075a0d3 --- /dev/null +++ b/management/server/idp/pocketid_test.go @@ -0,0 +1,138 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + + +func TestNewPocketIdManager(t *testing.T) { + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + 
inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } +} + +func TestPocketID_GetUserDataByID(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) +} + +func TestPocketID_GetAccount_WithPagination(t *testing.T) { + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + 
users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) +} + +func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) +} + +func TestPocketID_CreateUser(t *testing.T) { + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) +} + +func TestPocketID_InviteAndDeleteUser(t *testing.T) { + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) 
+ mgr.httpClient = client + + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) + + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) +} diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,7 +146,11 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return 
validatedPeers, nil } -func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index ce632d567..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,18 +3,19 @@ package integrated_validator import ( "context" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) - PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, 
temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2d691ba03..8baffa58b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -190,17 +190,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/management/server/peer.go b/management/server/peer.go index ba9b71f26..e06d6c46d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -357,7 +357,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer - var updateAccountPeers bool var eventsToStore 
[]func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -370,11 +369,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) - if err != nil { - return err - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -394,7 +388,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers && am.experimentalNetworkMap(accountID) { + if am.experimentalNetworkMap(accountID) { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return err @@ -406,7 +400,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } - if updateAccountPeers && userID != activity.SystemInitiator { + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -609,7 +603,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -709,11 +703,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) - if err != nil { - updateAccountPeers = true - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -737,9 +726,7 @@ func (am *DefaultAccountManager) AddPeer(ctx 
context.Context, accountID, setupKe } } - if updateAccountPeers { - am.BufferUpdateAccountPeers(ctx, accountID) - } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } @@ -1598,16 +1585,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) - if err != nil { - return false, err - } - return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction -} - // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. 
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index f16bdeda2..dbdfb88d2 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1179,7 +1179,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1809,7 +1809,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) // + peerShouldReceiveUpdate(t, updMsg) // close(done) }() @@ -1834,7 +1834,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" 
"github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. 
+func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: 
"router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + 
result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go index 1c0701401..c85dac973 100644 --- a/management/server/types/networkmap.go +++ b/management/server/types/networkmap.go @@ -109,9 +109,8 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -894,7 +893,7 @@ func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. 
-func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -905,6 +904,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" 
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..520a83e36 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. 
peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ad82d37d9..3982ea2af 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -410,7 +410,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is 
configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, 
nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) { diff --git a/version/url_windows.go b/version/url_windows.go index f2055b109..14fdb7ae6 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -1,9 +1,13 @@ package version -import "golang.org/x/sys/windows/registry" +import ( + "golang.org/x/sys/windows/registry" + "runtime" +) const ( urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird" @@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne // DownloadUrl return with the proper download link func DownloadUrl() string { _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE) - if err == nil { - return urlWinExe - } else { + if err != nil { return downloadURL } + + url := urlWinExe + if runtime.GOARCH == "arm64" { + url = urlWinExeArm + } + + return url }