From 2248ff392f659f4e941d8f337683f8e2344c60de Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 27 Jan 2026 20:10:59 +0100 Subject: [PATCH 01/27] Remove redundant square bracket trimming in USP endpoint parsing (#5197) --- client/iface/configurer/usp.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index c4ea349df..1298c609d 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) { continue } - host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]")) + host, portStr, err := net.SplitHostPort(val) if err != nil { log.Errorf("failed to parse endpoint: %v", err) continue From b55262d4a21cee277614e2d0ddd156d99112a9fc Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 28 Jan 2026 15:06:59 +0100 Subject: [PATCH 02/27] [client] Refactor/optimise raw socket headers (#5174) Pre-create and reuse packet headers to eliminate per-packet allocations. --- client/iface/wgproxy/ebpf/proxy.go | 72 ------------- client/iface/wgproxy/ebpf/wrapper.go | 153 +++++++++++++++++++++++++-- 2 files changed, 147 insertions(+), 78 deletions(-) diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 5458519fa..1b1a8ce1c 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -8,8 +8,6 @@ import ( "net" "sync" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/hashicorp/go-multierror" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" @@ -26,16 +24,6 @@ const ( loopbackAddr = "127.0.0.1" ) -var ( - localHostNetIPv4 = net.ParseIP("127.0.0.1") - localHostNetIPv6 = net.ParseIP("::1") - - serializeOpts = gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } -) - // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -253,63 +241,3 @@ generatePort: } return p.lastUsedPort, nil } - -func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { - - var ipH gopacket.SerializableLayer - var networkLayer gopacket.NetworkLayer - var dstIP net.IP - var rawConn net.PacketConn - - if endpointAddr.IP.To4() != nil { - // IPv4 path - ipv4 := &layers.IPv4{ - DstIP: localHostNetIPv4, - SrcIP: endpointAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - } - ipH = ipv4 - networkLayer = ipv4 - dstIP = localHostNetIPv4 - rawConn = p.rawConnIPv4 - } else { - // IPv6 path - if p.rawConnIPv6 == nil { - return fmt.Errorf("IPv6 raw socket not available") - } - ipv6 := &layers.IPv6{ - DstIP: localHostNetIPv6, - SrcIP: endpointAddr.IP, - Version: 6, - HopLimit: 64, - NextHeader: layers.IPProtocolUDP, - } - ipH = ipv6 - networkLayer = ipv6 - dstIP = localHostNetIPv6 - rawConn = p.rawConnIPv6 - } - - udpH := &layers.UDP{ - SrcPort: layers.UDPPort(endpointAddr.Port), - DstPort: layers.UDPPort(p.localWGListenPort), - } - - if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { - return fmt.Errorf("set network layer for checksum: %w", err) - } - - layerBuffer := gopacket.NewSerializeBuffer() - payload := gopacket.Payload(data) - - if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil { - return fmt.Errorf("serialize layers: %w", err) - } - - if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil { - return fmt.Errorf("write to raw conn: %w", err) - } - return nil -} diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 5b98be7b4..6e80945c4 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -10,12 +10,89 @@ import ( "net" "sync" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +var ( + errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available") + errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available") + + localHostNetIPv4 = net.ParseIP("127.0.0.1") + localHostNetIPv6 = net.ParseIP("::1") + + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } +) + +// PacketHeaders holds pre-created headers and buffers for efficient packet sending +type PacketHeaders struct { + ipH gopacket.SerializableLayer + udpH *layers.UDP + layerBuffer gopacket.SerializeBuffer + localHostAddr net.IP + isIPv4 bool +} + +func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) { + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + var localHostAddr net.IP + var isIPv4 bool + + // Check if source address is IPv4 or IPv6 + if endpoint.IP.To4() != nil { + // IPv4 path + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPv4, + SrcIP: endpoint.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + localHostAddr = localHostNetIPv4 + isIPv4 = true + } else { + // IPv6 path + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPv6, + SrcIP: endpoint.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 + localHostAddr = localHostNetIPv6 + isIPv4 = false + } + + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(endpoint.Port), + DstPort: layers.UDPPort(localWGListenPort), + } + + if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { + return nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return &PacketHeaders{ + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + localHostAddr: localHostAddr, + isIPv4: isIPv4, + }, nil +} + // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { wgeBPFProxy *WGEBPFProxy @@ -24,8 +101,10 @@ type ProxyWrapper struct { ctx context.Context cancel context.CancelFunc - wgRelayedEndpointAddr *net.UDPAddr - wgEndpointCurrentUsedAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + headers *PacketHeaders + headerCurrentUsed *PacketHeaders + rawConn net.PacketConn paused bool pausedCond *sync.Cond @@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { closeListener: listener.NewCloseListener(), } } + func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } + + headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr) + if err != nil { + return fmt.Errorf("create packet sender: %w", err) + } + + // Check if required raw connection is available + if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil { + return errIPv6ConnNotAvailable + } + if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil { + return errIPv4ConnNotAvailable + } + p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) p.wgRelayedEndpointAddr = addr - return err + p.headers = headers + p.rawConn = p.selectRawConn(headers) + return nil } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { @@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() { p.pausedCond.L.Lock() p.paused = false - p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr + p.headerCurrentUsed = p.headers + p.rawConn = p.selectRawConn(p.headerCurrentUsed) if !p.isStarted { p.isStarted = true @@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { log.Errorf("failed to start package redirection, endpoint is nil") return } + + header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create packet headers: %s", err) + return + } + + // Check if required raw connection is available + if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil { + log.Error(errIPv6ConnNotAvailable) + return + } + if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil { + log.Error(errIPv4ConnNotAvailable) + return + } + p.pausedCond.L.Lock() p.paused = false - p.wgEndpointCurrentUsedAddr = endpoint + p.headerCurrentUsed = header + p.rawConn = p.selectRawConn(header) p.pausedCond.Signal() p.pausedCond.L.Unlock() @@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { p.pausedCond.Wait() } - err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + err = p.sendPkg(buf[:n], p.headerCurrentUsed) p.pausedCond.L.Unlock() if err != nil { @@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } return n, nil } + +func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error { + defer func() { + if err := header.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil { + return fmt.Errorf("serialize layers: %w", err) + } + + if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil { + return fmt.Errorf("write to raw conn: %w", err) + } + return nil +} + +func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn { + if header.isIPv4 { + return p.wgeBPFProxy.rawConnIPv4 + } + return p.wgeBPFProxy.rawConnIPv6 +} From cead3f38ee4912fe0d3962bf2a1d14d927a7eed3 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 28 Jan 2026 18:24:12 +0100 Subject: [PATCH 03/27] [management] fix ephemeral peers being not removed (#5203) --- management/internals/shared/grpc/server.go | 4 +++- management/server/peer.go | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 32049d044..219baaf6d 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -17,13 +17,14 @@ import ( pb "github.com/golang/protobuf/proto" // nolint "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" - "github.com/netbirdio/netbird/shared/management/client/common" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/shared/management/client/common" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/idp" @@ -304,6 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) + s.cancelPeerRoutines(ctx, accountID, peer) return err } diff --git a/management/server/peer.go b/management/server/peer.go index 80c74e209..ab72d3051 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -728,11 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } - if temporary { - // we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected - am.networkMapController.TrackEphemeralPeer(ctx, newPeer) - } - if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -760,6 +755,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed to increment network serial: %w", err) } + if ephemeral { + // we should track ephemeral peers to be able to clean them if the peer doesn't sync and isn't marked as connected + am.networkMapController.TrackEphemeralPeer(ctx, newPeer) + } + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) return nil }) From 0169e4540fbbe04c88c8600ae0cd0033211f893c Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:58:45 +0100 Subject: [PATCH 04/27] [management] fix skip of ephemeral peers on deletion (#5206) --- .../modules/peers/ephemeral/manager/ephemeral.go | 4 ++-- management/internals/modules/peers/manager.go | 12 +++++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/management/internals/modules/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go index 15119045b..758f643d0 100644 --- a/management/internals/modules/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } for accountID, peerIDs := range peerIDsPerAccount { - log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) + log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID) err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { - log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) + log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err) } } } diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go index 1551689b4..7ac2e379f 100644 --- a/management/internals/modules/peers/manager.go +++ b/management/internals/modules/peers/manager.go @@ -108,10 +108,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { + if e, ok := status.FromError(err); ok && e.Type() == status.NotFound { + log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID) + return nil + } return err } if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { + log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)", + peerID, peer.Status.Connected, + peer.Status.LastSeen.Format(time.RFC3339), + time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339), + peer.Ephemeral) return nil } @@ -150,7 +159,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs return nil }) if err != nil { - return err + log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err) + continue } if m.integratedPeerValidator != nil { From f74bc48d16c1c7df28ef47e640507804c9e80b55 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 29 Jan 2026 18:05:06 +0800 Subject: [PATCH 05/27] [Client] Stop NetBird on firewall init failure (#5208) --- client/internal/engine.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index f0693e82c..63ba1c9f2 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -573,9 +573,11 @@ func (e *Engine) createFirewall() error { var err error e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) - if err != nil || e.firewall == nil { - log.Errorf("failed creating firewall manager: %s", err) - return nil + if err != nil { + return fmt.Errorf("create firewall manager: %w", err) + } + if e.firewall == nil { + return fmt.Errorf("create firewall manager: received nil manager") } if err := e.initFirewall(); err != nil { From 81c11df1034956b82a036f209e6a1b0af2f037f2 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 29 Jan 2026 20:51:44 +0800 Subject: [PATCH 06/27] [management] Streamline domain validation (#5211) --- client/internal/dns/handler_chain_test.go | 48 +++++ management/cmd/management.go | 7 +- .../internals/modules/zones/records/record.go | 8 +- management/internals/modules/zones/zone.go | 4 +- management/server/account.go | 11 +- management/server/nameserver.go | 25 +-- management/server/nameserver_test.go | 85 +++----- .../server/networks/resources/manager_test.go | 6 +- .../networks/resources/types/resource.go | 4 +- .../networks/resources/types/resource_test.go | 6 +- management/server/util/util.go | 10 - shared/management/domain/validate.go | 30 ++- shared/management/domain/validate_test.go | 200 +++++++++++++++++- 13 files changed, 339 insertions(+), 105 deletions(-) diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 72c0004d5..fa9525069 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { matchSubdomains: false, shouldMatch: false, }, + { + name: "single letter TLD exact match", + handlerDomain: "example.x.", + queryDomain: "example.x.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "single letter TLD subdomain match", + handlerDomain: "example.x.", + queryDomain: "sub.example.x.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, + { + name: "single letter TLD wildcard match", + handlerDomain: "*.example.x.", + queryDomain: "sub.example.x.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "two letter domain labels", + handlerDomain: "a.b.", + queryDomain: "a.b.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "single character domain", + handlerDomain: "x.", + queryDomain: "x.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "single character domain with subdomain match", + handlerDomain: "x.", + queryDomain: "sub.x.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, } for _, tt := range tests { diff --git a/management/cmd/management.go b/management/cmd/management.go index 7da04074b..511168823 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -16,13 +16,13 @@ import ( "strings" "syscall" - "github.com/miekg/dns" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/server" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util/crypt" ) @@ -78,9 +78,8 @@ var ( } } - _, valid := dns.IsDomainName(dnsDomain) - if !valid || len(dnsDomain) > 192 { - return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain)) + if !nbdomain.IsValidDomainNoWildcard(dnsDomain) { + return fmt.Errorf("invalid dns-domain: %s", dnsDomain) } return nil diff --git a/management/internals/modules/zones/records/record.go b/management/internals/modules/zones/records/record.go index e44de08f4..1488febb9 100644 --- a/management/internals/modules/zones/records/record.go +++ b/management/internals/modules/zones/records/record.go @@ -6,7 +6,7 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -63,7 +63,7 @@ func (r *Record) Validate() error { return errors.New("record name is required") } - if !util.IsValidDomain(r.Name) { + if !domain.IsValidDomain(r.Name) { return errors.New("invalid record name format") } @@ -81,8 +81,8 @@ func (r *Record) Validate() error { return err } case RecordTypeCNAME: - if !util.IsValidDomain(r.Content) { - return errors.New("invalid CNAME record format") + if !domain.IsValidDomainNoWildcard(r.Content) { + return errors.New("invalid CNAME target format") } default: return errors.New("invalid record type, must be A, AAAA, or CNAME") diff --git a/management/internals/modules/zones/zone.go b/management/internals/modules/zones/zone.go index 27adac1ac..f5ebed26c 100644 --- a/management/internals/modules/zones/zone.go +++ b/management/internals/modules/zones/zone.go @@ -6,7 +6,7 @@ import ( "github.com/rs/xid" "github.com/netbirdio/netbird/management/internals/modules/zones/records" - "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -73,7 +73,7 @@ func (z *Zone) Validate() error { return errors.New("zone name exceeds maximum length of 255 characters") } - if !util.IsValidDomain(z.Domain) { + if !domain.IsValidDomainNoWildcard(z.Domain) { return errors.New("invalid zone domain format") } diff --git a/management/server/account.go b/management/server/account.go index d453b87c3..ba5f0cffa 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,6 +26,7 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -231,7 +232,7 @@ func BuildManager( // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { - if !isDomainValid(singleAccountModeDomain) { + if !nbdomain.IsValidDomainNoWildcard(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain @@ -402,7 +403,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + if newSettings.DNSDomain != "" && !nbdomain.IsValidDomainNoWildcard(newSettings.DNSDomain) { return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } @@ -1691,10 +1692,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return nil } -var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) +// isDomainValid validates public/IDP domains using stricter rules than internal DNS domains. +// Requires at least 2-char alphabetic TLD and no single-label domains. +var publicDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { - return invalidDomainRegexp.MatchString(domain) + return publicDomainRegexp.MatchString(domain) } func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index a3eb4ae2e..3d8c78912 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -3,10 +3,10 @@ package server import ( "context" "errors" - "regexp" + "fmt" + "strings" "unicode/utf8" - "github.com/miekg/dns" "github.com/rs/xid" nbdns "github.com/netbirdio/netbird/dns" @@ -15,11 +15,10 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) -const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` - var errInvalidDomainName = errors.New("invalid domain name") // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs @@ -305,16 +304,18 @@ func validateGroups(list []string, groups map[string]*types.Group) error { return nil } -var domainMatcher = regexp.MustCompile(domainPattern) - -func validateDomain(domain string) error { - if !domainMatcher.MatchString(domain) { - return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") +// validateDomain validates a nameserver match domain. +// Converts unicode to punycode. Wildcards are not allowed for nameservers. +func validateDomain(d string) error { + if strings.HasPrefix(d, "*.") { + return errors.New("wildcards not allowed") } - _, valid := dns.IsDomainName(domain) - if !valid { - return errInvalidDomainName + // Nameservers allow trailing dot (FQDN format) + toValidate := strings.TrimSuffix(d, ".") + + if _, err := nbdomain.ValidateDomains([]string{toValidate}); err != nil { + return fmt.Errorf("%w: %w", errInvalidDomainName, err) } return nil diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 0d781e0d4..90b4b9687 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -901,82 +901,53 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, return account, nil } +// TestValidateDomain tests nameserver-specific domain validation. +// Core domain validation is tested in shared/management/domain/validate_test.go. +// This test only covers nameserver-specific behavior: wildcard rejection and unicode support. func TestValidateDomain(t *testing.T) { testCases := []struct { name string domain string errFunc require.ErrorAssertionFunc }{ + // Nameserver-specific: wildcards not allowed { - name: "Valid domain name with multiple labels", - domain: "123.example.com", + name: "Wildcard prefix rejected", + domain: "*.example.com", + errFunc: require.Error, + }, + { + name: "Wildcard in middle rejected", + domain: "a.*.example.com", + errFunc: require.Error, + }, + // Nameserver-specific: unicode converted to punycode + { + name: "Unicode domain converted to punycode", + domain: "münchen.de", errFunc: require.NoError, }, { - name: "Valid domain name with hyphen", - domain: "test-example.com", + name: "Unicode domain all labels", + domain: "中国.中国", + errFunc: require.NoError, + }, + // Basic validation still works (delegates to shared validation) + { + name: "Valid multi-label domain", + domain: "example.com", errFunc: require.NoError, }, { - name: "Valid domain name with only one label", - domain: "example", + name: "Valid single label", + domain: "internal", errFunc: require.NoError, }, { - name: "Valid domain name with trailing dot", - domain: "example.", - errFunc: require.NoError, - }, - { - name: "Invalid wildcard domain name", - domain: "*.example", - errFunc: require.Error, - }, - { - name: "Invalid domain name with leading dot", - domain: ".com", - errFunc: require.Error, - }, - { - name: "Invalid domain name with dot only", - domain: ".", - errFunc: require.Error, - }, - { - name: "Invalid domain name with double hyphen", - domain: "test--example.com", - errFunc: require.Error, - }, - { - name: "Invalid domain name with a label exceeding 63 characters", - domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com", - errFunc: require.Error, - }, - { - name: "Invalid domain name starting with a hyphen", + name: "Invalid leading hyphen", domain: "-example.com", errFunc: require.Error, }, - { - name: "Invalid domain name ending with a hyphen", - domain: "example.com-", - errFunc: require.Error, - }, - { - name: "Invalid domain with unicode", - domain: "example?,.com", - errFunc: require.Error, - }, - { - name: "Invalid domain with space before top-level domain", - domain: "space .example.com", - errFunc: require.Error, - }, - { - name: "Invalid domain with trailing space", - domain: "example.com ", - errFunc: require.Error, - }, } for _, testCase := range testCases { diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index e2dea2c6b..29b0af2cc 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -203,7 +203,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { NetworkID: "testNetworkId", Name: "testResourceId", Description: "description", - Address: "invalid-address", + Address: "-invalid", } store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) @@ -227,9 +227,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { resource := &types.NetworkResource{ AccountID: "testAccountId", NetworkID: "testNetworkId", - Name: "testResourceId", + Name: "used-name", Description: "description", - Address: "invalid-address", + Address: "example.com", } store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 6b8cf9412..1fa908393 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/netip" - "regexp" "github.com/rs/xid" @@ -166,8 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil } - domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) - if domainRegex.MatchString(address) { + if _, err := nbDomain.ValidateDomains([]string{address}); err == nil { return Domain, address, netip.Prefix{}, nil } diff --git a/management/server/networks/resources/types/resource_test.go b/management/server/networks/resources/types/resource_test.go index 02e802300..a842b0a28 100644 --- a/management/server/networks/resources/types/resource_test.go +++ b/management/server/networks/resources/types/resource_test.go @@ -23,10 +23,12 @@ func TestGetResourceType(t *testing.T) { {"example.com", Domain, false, "example.com", netip.Prefix{}}, {"*.example.com", Domain, false, "*.example.com", netip.Prefix{}}, {"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}}, + {"example.x", Domain, false, "example.x", netip.Prefix{}}, + {"internal", Domain, false, "internal", netip.Prefix{}}, // Invalid inputs - {"invalid", "", true, "", netip.Prefix{}}, {"1.1.1.1/abc", "", true, "", netip.Prefix{}}, - {"1234", "", true, "", netip.Prefix{}}, + {"-invalid.com", "", true, "", netip.Prefix{}}, + {"", "", true, "", netip.Prefix{}}, } for _, tt := range tests { diff --git a/management/server/util/util.go b/management/server/util/util.go index eea6a72b0..ce9759864 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -1,9 +1,5 @@ package util -import "regexp" - -var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) - // Difference returns the elements in `a` that aren't in `b`. func Difference(a, b []string) []string { mb := make(map[string]struct{}, len(b)) @@ -55,9 +51,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool { return false } -func IsValidDomain(domain string) bool { - if domain == "" { - return false - } - return domainRegex.MatchString(domain) -} diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index bf2af7116..1858b5d55 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -10,7 +10,30 @@ const maxDomains = 32 var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) -// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. +// IsValidDomain checks if a single domain string is valid. +// Does not convert unicode to punycode - domain must already be ASCII/punycode. +// Allows wildcard prefix (*.example.com). +func IsValidDomain(domain string) bool { + if domain == "" { + return false + } + return domainRegex.MatchString(strings.ToLower(domain)) +} + +// IsValidDomainNoWildcard checks if a single domain string is valid without wildcard prefix. +// Use for zone domains and CNAME targets where wildcards are not allowed. +func IsValidDomainNoWildcard(domain string) bool { + if domain == "" { + return false + } + if strings.HasPrefix(domain, "*.") { + return false + } + return domainRegex.MatchString(strings.ToLower(domain)) +} + +// ValidateDomains validates domains and converts unicode to punycode. +// Allows wildcard prefix (*.example.com). Maximum 32 domains. func ValidateDomains(domains []string) (List, error) { if len(domains) == 0 { return nil, fmt.Errorf("domains list is empty") @@ -37,7 +60,10 @@ func ValidateDomains(domains []string) (List, error) { return domainList, nil } -// ValidateDomainsList checks if each domain in the list is valid +// ValidateDomainsList validates domains without punycode conversion. +// Use this for domains that must already be in ASCII/punycode format (e.g., extra DNS labels). +// Unlike ValidateDomains, this does not convert unicode to punycode - unicode domains will fail. +// Allows wildcard prefix (*.example.com). Maximum 32 domains. func ValidateDomainsList(domains []string) error { if len(domains) == 0 { return nil diff --git a/shared/management/domain/validate_test.go b/shared/management/domain/validate_test.go index 30efcd9a9..9dbcd8ac8 100644 --- a/shared/management/domain/validate_test.go +++ b/shared/management/domain/validate_test.go @@ -2,12 +2,16 @@ package domain import ( "fmt" + "strings" "testing" "github.com/stretchr/testify/assert" ) func TestValidateDomains(t *testing.T) { + label63 := strings.Repeat("a", 63) + label64 := strings.Repeat("a", 64) + tests := []struct { name string domains []string @@ -26,6 +30,48 @@ func TestValidateDomains(t *testing.T) { expected: List{"sub.ex-ample.com"}, wantErr: false, }, + { + name: "Valid uppercase domain normalized to lowercase", + domains: []string{"EXAMPLE.COM"}, + expected: List{"example.com"}, + wantErr: false, + }, + { + name: "Valid mixed case domain", + domains: []string{"ExAmPlE.CoM"}, + expected: List{"example.com"}, + wantErr: false, + }, + { + name: "Single letter TLD", + domains: []string{"example.x"}, + expected: List{"example.x"}, + wantErr: false, + }, + { + name: "Two letter domain labels", + domains: []string{"a.b"}, + expected: List{"a.b"}, + wantErr: false, + }, + { + name: "Single character domain", + domains: []string{"x"}, + expected: List{"x"}, + wantErr: false, + }, + { + name: "Wildcard with single letter TLD", + domains: []string{"*.x"}, + expected: List{"*.x"}, + wantErr: false, + }, + { + name: "Multi-level with single letter labels", + domains: []string{"a.b.c"}, + expected: List{"a.b.c"}, + wantErr: false, + }, { name: "Valid Unicode domain", domains: []string{"münchen.de"}, @@ -45,17 +91,92 @@ func TestValidateDomains(t *testing.T) { wantErr: false, }, { - name: "Invalid domain format", + name: "Valid domain starting with digit", + domains: []string{"123.example.com"}, + expected: List{"123.example.com"}, + wantErr: false, + }, + // Numeric TLDs are allowed for internal/private DNS use cases. + // While ICANN doesn't issue all-numeric gTLDs, the DNS protocol permits them + // and resolvers like systemd-resolved handle them correctly. + { + name: "Numeric TLD allowed", + domains: []string{"example.123"}, + expected: List{"example.123"}, + wantErr: false, + }, + { + name: "Single digit TLD allowed", + domains: []string{"example.1"}, + expected: List{"example.1"}, + wantErr: false, + }, + { + name: "All numeric labels allowed", + domains: []string{"123.456"}, + expected: List{"123.456"}, + wantErr: false, + }, + { + name: "Single numeric label allowed", + domains: []string{"123"}, + expected: List{"123"}, + wantErr: false, + }, + { + name: "Valid domain with double hyphen", + domains: []string{"test--example.com"}, + expected: List{"test--example.com"}, + wantErr: false, + }, + { + name: "Invalid leading hyphen", domains: []string{"-example.com"}, expected: nil, wantErr: true, }, { - name: "Invalid domain format 2", + name: "Invalid trailing hyphen", domains: []string{"example.com-"}, expected: nil, wantErr: true, }, + { + name: "Invalid leading dot", + domains: []string{".com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid dot only", + domains: []string{"."}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid double dot", + domains: []string{"example..com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid special characters", + domains: []string{"example?,.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid space in domain", + domains: []string{"space .example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid trailing space", + domains: []string{"example.com "}, + expected: nil, + wantErr: true, + }, { name: "Multiple domains valid and invalid", domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"}, @@ -86,6 +207,30 @@ func TestValidateDomains(t *testing.T) { expected: nil, wantErr: true, }, + { + name: "Valid 63 char label (max)", + domains: []string{label63 + ".com"}, + expected: List{Domain(label63 + ".com")}, + wantErr: false, + }, + { + name: "Invalid 64 char label (exceeds max)", + domains: []string{label64 + ".com"}, + expected: nil, + wantErr: true, + }, + { + name: "Valid 253 char domain (max)", + domains: []string{strings.Repeat("a.", 126) + "a"}, + expected: List{Domain(strings.Repeat("a.", 126) + "a")}, + wantErr: false, + }, + { + name: "Invalid 254+ char domain (exceeds max)", + domains: []string{strings.Repeat("ab.", 85)}, + expected: nil, + wantErr: true, + }, } for _, tt := range tests { @@ -118,6 +263,57 @@ func TestValidateDomainsList(t *testing.T) { domains: []string{"sub.ex-ample.com"}, wantErr: false, }, + { + name: "Uppercase domain accepted", + domains: []string{"EXAMPLE.COM"}, + wantErr: false, + }, + { + name: "Single letter TLD", + domains: []string{"example.x"}, + wantErr: false, + }, + { + name: "Two letter domain labels", + domains: []string{"a.b"}, + wantErr: false, + }, + { + name: "Single character domain", + domains: []string{"x"}, + wantErr: false, + }, + { + name: "Wildcard with single letter TLD", + domains: []string{"*.x"}, + wantErr: false, + }, + { + name: "Multi-level with single letter labels", + domains: []string{"a.b.c"}, + wantErr: false, + }, + // Numeric TLDs are allowed for internal/private DNS use cases. + { + name: "Numeric TLD allowed", + domains: []string{"example.123"}, + wantErr: false, + }, + { + name: "Single digit TLD allowed", + domains: []string{"example.1"}, + wantErr: false, + }, + { + name: "All numeric labels allowed", + domains: []string{"123.456"}, + wantErr: false, + }, + { + name: "Single numeric label allowed", + domains: []string{"123"}, + wantErr: false, + }, { name: "Underscores in labels", domains: []string{"_jabber._tcp.gmail.com"}, From 5333e55a8134b7f0feb58777e92c32509abcd037 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 29 Jan 2026 16:58:10 +0100 Subject: [PATCH 07/27] Fix WG watcher missing initial handshake (#5213) Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp. Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup. --- client/internal/peer/conn.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 39133a6d3..eb455431d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn } conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) + conn.enableWgWatcherIfNeeded() + presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) @@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn conn.wgProxyRelay.RedirectAs(ep) } - conn.enableWgWatcherIfNeeded() - conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { wgProxy.Work() presharedKey := conn.presharedKey(rci.rosenpassPubKey) + + conn.enableWgWatcherIfNeeded() + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) @@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { return } - conn.enableWgWatcherIfNeeded() - wgConfigWorkaround() conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.currentConnPriority = conntype.Relay From 101c813e9846bc30ce6c00766e6befdaf7dc965c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:42:14 +0800 Subject: [PATCH 08/27] [client] Add macOS default resolvers as fallback (#5201) --- client/internal/dns/host_darwin.go | 25 +++- client/internal/dns/host_darwin_test.go | 166 ++++++++++++++++++++++++ client/internal/dns/server.go | 3 +- client/internal/dns/test/mock.go | 6 + 4 files changed, 196 insertions(+), 4 deletions(-) diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 71badf0d4..af84c8a85 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -9,8 +9,10 @@ import ( "io" "net/netip" "os/exec" + "slices" "strconv" "strings" + "sync" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -38,6 +40,9 @@ const ( type systemConfigurator struct { createdKeys map[string]struct{} systemDNSSettings SystemDNSSettings + + mu sync.RWMutex + origNameservers []netip.Addr } func newHostManager() (*systemConfigurator, error) { @@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } var dnsSettings SystemDNSSettings + var serverAddresses []netip.Addr inSearchDomainsArray := false inServerAddressesArray := false @@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] - if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { - dnsSettings.ServerIP = ip.Unmap() - inServerAddressesArray = false // Stop reading after finding the first IPv4 address + if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() { + ip = ip.Unmap() + serverAddresses = append(serverAddresses, ip) + if !dnsSettings.ServerIP.IsValid() && ip.Is4() { + dnsSettings.ServerIP = ip + } } } } @@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { // default to 53 port dnsSettings.ServerPort = DefaultPort + s.mu.Lock() + s.origNameservers = serverAddresses + s.mu.Unlock() + return dnsSettings, nil } +func (s *systemConfigurator) getOriginalNameservers() []netip.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.origNameservers) +} + func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go index c4efd17b0..28915de65 100644 --- a/client/internal/dns/host_darwin_test.go +++ b/client/internal/dns/host_darwin_test.go @@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error { _, err := cmd.CombinedOutput() return err } + +func TestGetOriginalNameservers(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + origNameservers: []netip.Addr{ + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("1.1.1.1"), + }, + } + + servers := configurator.getOriginalNameservers() + assert.Len(t, servers, 2) + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0]) + assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1]) +} + +func TestGetOriginalNameserversFromSystem(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + + servers := configurator.getOriginalNameservers() + + require.NotEmpty(t, servers, "expected at least one DNS server from system configuration") + + for _, server := range servers { + assert.True(t, server.IsValid(), "server address should be valid") + assert.False(t, server.IsUnspecified(), "server address should not be unspecified") + } + + t.Logf("found %d original nameservers: %v", len(servers), servers) +} + +func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) { + t.Helper() + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + cleanup := func() { + _ = sm.Stop(context.Background()) + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + } + + return configurator, sm, cleanup +} + +func TestOriginalNameserversNoTransition(t *testing.T) { + netbirdIP := netip.MustParseAddr("100.64.0.1") + + testCases := []struct { + name string + routeAll bool + }{ + {"routeall_false", false}, + {"routeall_true", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator, sm, cleanup := setupTestConfigurator(t) + defer cleanup() + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + initialServers := configurator.getOriginalNameservers() + t.Logf("Initial servers: %v", initialServers) + require.NotEmpty(t, initialServers) + + for _, srv := range initialServers { + require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP") + } + + config := HostDNSConfig{ + ServerIP: netbirdIP, + ServerPort: 53, + RouteAll: tc.routeAll, + Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, + } + + for i := 1; i <= 2; i++ { + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + servers := configurator.getOriginalNameservers() + t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers) + assert.Equal(t, initialServers, servers) + } + }) + } +} + +func TestOriginalNameserversRouteAllTransition(t *testing.T) { + netbirdIP := netip.MustParseAddr("100.64.0.1") + + testCases := []struct { + name string + initialRoute bool + }{ + {"start_with_routeall_false", false}, + {"start_with_routeall_true", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator, sm, cleanup := setupTestConfigurator(t) + defer cleanup() + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + initialServers := configurator.getOriginalNameservers() + t.Logf("Initial servers: %v", initialServers) + require.NotEmpty(t, initialServers) + + config := HostDNSConfig{ + ServerIP: netbirdIP, + ServerPort: 53, + RouteAll: tc.initialRoute, + Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, + } + + // First apply + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers := configurator.getOriginalNameservers() + t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers) + assert.Equal(t, initialServers, servers) + + // Toggle RouteAll + config.RouteAll = !tc.initialRoute + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers = configurator.getOriginalNameservers() + t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers) + assert.Equal(t, initialServers, servers) + + // Toggle back + config.RouteAll = tc.initialRoute + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers = configurator.getOriginalNameservers() + t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers) + assert.Equal(t, initialServers, servers) + + for _, srv := range servers { + assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP") + } + }) + } +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1ce7bf1c6..4d4fcc06e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() { s.registerFallback(config) } -// registerFallback registers original nameservers as low-priority fallback handlers +// registerFallback registers original nameservers as low-priority fallback handlers. func (s *DefaultServer) registerFallback(config HostDNSConfig) { hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) if !ok { @@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { originalNameservers := hostMgrWithNS.getOriginalNameservers() if len(originalNameservers) == 0 { + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) return } diff --git a/client/internal/dns/test/mock.go b/client/internal/dns/test/mock.go index 1db452805..8d16689bf 100644 --- a/client/internal/dns/test/mock.go +++ b/client/internal/dns/test/mock.go @@ -8,15 +8,21 @@ import ( type MockResponseWriter struct { WriteMsgFunc func(m *dns.Msg) error + lastResponse *dns.Msg } func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { + rw.lastResponse = m if rw.WriteMsgFunc != nil { return rw.WriteMsgFunc(m) } return nil } +func (rw *MockResponseWriter) GetLastResponse() *dns.Msg { + return rw.lastResponse +} + func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } From 0c990ab6623530b2ad6925a8dce04bdcc2455baa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:42:39 +0800 Subject: [PATCH 09/27] [client] Add block inbound option to the embed client (#5215) --- client/embed/embed.go | 3 +++ client/internal/networkmonitor/monitor.go | 6 ++++++ client/internal/wg_iface_monitor.go | 7 +++++++ 3 files changed, 16 insertions(+) diff --git a/client/embed/embed.go b/client/embed/embed.go index e266aae28..e73f37e35 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -69,6 +69,8 @@ type Options struct { StatePath string // DisableClientRoutes disables the client routes DisableClientRoutes bool + // BlockInbound blocks all inbound connections from peers + BlockInbound bool } // validateCredentials checks that exactly one credential type is provided @@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) { PreSharedKey: &opts.PreSharedKey, DisableServerRoutes: &t, DisableClientRoutes: &opts.DisableClientRoutes, + BlockInbound: &opts.BlockInbound, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 6d019258d..6dd81f68c 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -14,6 +14,7 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) @@ -37,6 +38,11 @@ func New() *NetworkMonitor { // Listen begins monitoring network changes. When a change is detected, this function will return without error. func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { + if netstack.IsEnabled() { + log.Debugf("Network monitor: skipping in netstack mode") + return nil + } + nw.mu.Lock() if nw.cancel != nil { nw.mu.Unlock() diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go index 78d70c15b..a870c1145 100644 --- a/client/internal/wg_iface_monitor.go +++ b/client/internal/wg_iface_monitor.go @@ -9,6 +9,8 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" ) // WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine @@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes return false, errors.New("not supported on mobile platforms") } + if netstack.IsEnabled() { + log.Debugf("Interface monitor: skipped in netstack mode") + return false, nil + } + if ifaceName == "" { log.Debugf("Interface monitor: empty interface name, skipping monitor") return false, errors.New("empty interface name") From 3a0cf230a179e767434018814a0dee5be2c526dd Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 1 Feb 2026 14:26:22 +0100 Subject: [PATCH 10/27] Disable local users for a smooth single-idp mode (#5226) Add LocalAuthDisabled option to embedded IdP configuration This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external identity providers (Google, OIDC, etc.). This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page. Key changes: Added LocalAuthDisabled field to EmbeddedIdPConfig Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth Validation prevents disabling local auth if no external connectors are configured Existing local users are preserved when disabled and can login again when re-enabled Operations are idempotent (disabling already disabled is a no-op) --- idp/dex/connector.go | 54 ++++ management/internals/server/modules.go | 9 +- management/server/account.go | 15 +- management/server/http/handler.go | 4 +- .../handlers/accounts/accounts_handler.go | 25 +- .../accounts/accounts_handler_test.go | 7 +- .../handlers/instance/instance_handler.go | 2 +- .../handlers/users/invites_handler_test.go | 17 ++ .../testing/testing_tools/channel/channel.go | 2 +- management/server/idp/embedded.go | 50 ++++ management/server/idp/embedded_test.go | 231 ++++++++++++++++++ management/server/instance/manager.go | 11 +- management/server/settings/manager.go | 15 +- management/server/types/settings.go | 10 + management/server/user.go | 12 + shared/management/http/api/openapi.yml | 5 + shared/management/http/api/types.gen.go | 3 + 17 files changed, 450 insertions(+), 22 deletions(-) diff --git a/idp/dex/connector.go b/idp/dex/connector.go index cad682141..ba2bb1f00 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error { return nil } +// HasNonLocalConnectors checks if there are any connectors other than the local connector. +func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) { + connectors, err := p.storage.ListConnectors(ctx) + if err != nil { + return false, fmt.Errorf("failed to list connectors: %w", err) + } + + p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors)) + for _, conn := range connectors { + p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name) + if conn.ID != "local" || conn.Type != "local" { + p.logger.Info("found non-local connector", "id", conn.ID) + return true, nil + } + } + p.logger.Info("no non-local connectors found") + return false, nil +} + +// DisableLocalAuth removes the local (password) connector. +// Returns an error if no other connectors are configured. +func (p *Provider) DisableLocalAuth(ctx context.Context) error { + hasOthers, err := p.HasNonLocalConnectors(ctx) + if err != nil { + return err + } + if !hasOthers { + return fmt.Errorf("cannot disable local authentication: no other identity providers configured") + } + + // Check if local connector exists + _, err = p.storage.GetConnector(ctx, "local") + if errors.Is(err, storage.ErrNotFound) { + // Already disabled + return nil + } + if err != nil { + return fmt.Errorf("failed to check local connector: %w", err) + } + + // Delete the local connector + if err := p.storage.DeleteConnector(ctx, "local"); err != nil { + return fmt.Errorf("failed to delete local connector: %w", err) + } + + p.logger.Info("local authentication disabled") + return nil +} + +// EnableLocalAuth creates the local (password) connector if it doesn't exist. +func (p *Provider) EnableLocalAuth(ctx context.Context) error { + return ensureLocalConnector(ctx, p.storage) +} + // ensureStaticConnectors creates or updates static connectors in storage func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error { for _, conn := range connectors { diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index b51e2ebb2..31badf9d0 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -69,7 +69,14 @@ func (s *BaseServer) UsersManager() users.Manager { func (s *BaseServer) SettingsManager() settings.Manager { return Create(s, func() settings.Manager { extraSettingsManager := integrations.NewManager(s.EventStore()) - return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager()) + + idpConfig := settings.IdpConfig{} + if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + idpConfig.EmbeddedIdpEnabled = true + idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled + } + + return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig) }) } diff --git a/management/server/account.go b/management/server/account.go index ba5f0cffa..8f9dad031 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,7 +26,6 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -49,6 +48,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) @@ -795,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool { return ok } +// IsLocalAuthDisabled checks if local (email/password) authentication is disabled. +// Returns true only when using embedded IDP with local auth disabled in config. +func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool { + if isNil(i) { + return false + } + embeddedIdp, ok := i.(*idp.EmbeddedIdPManager) + if !ok { + return false + } + return embeddedIdp.IsLocalAuthDisabled() +} + // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 32a97ff44..79431a0a3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -129,14 +129,14 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("register integrations endpoints: %w", err) } - // Check if embedded IdP is enabled + // Check if embedded IdP is enabled for instance manager embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager) instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP) if err != nil { return nil, fmt.Errorf("failed to create instance manager: %w", err) } - accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router) + accounts.AddEndpoints(accountManager, settingsManager, router) peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) users.AddInvitesEndpoints(accountManager, router) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index de778d59a..122c061ce 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -36,24 +36,22 @@ const ( // handler is a handler that handles the server.Account HTTP endpoints type handler struct { - accountManager account.Manager - settingsManager settings.Manager - embeddedIdpEnabled bool + accountManager account.Manager + settingsManager settings.Manager } -func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) { - accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled) +func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { + accountsHandler := newHandler(accountManager, settingsManager) router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler -func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler { +func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler { return &handler{ - accountManager: accountManager, - settingsManager: settingsManager, - embeddedIdpEnabled: embeddedIdpEnabled, + accountManager: accountManager, + settingsManager: settingsManager, } } @@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled) + resp := toAccountResponse(accountID, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) util.WriteJSONObject(r.Context(), w, &resp) } @@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, - EmbeddedIdpEnabled: &embeddedIdpEnabled, + EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, + LocalAuthDisabled: &settings.LocalAuthDisabled, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e455372c8..6cbd5908d 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { AnyTimes() return &handler{ - embeddedIdpEnabled: false, accountManager: &mock_server.MockAccountManager{ GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil @@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index 5d8baaf8d..cd9fae6b8 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w) return } - + log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired) util.WriteJSONObject(r.Context(), w, api.InstanceStatus{ SetupRequired: setupRequired, }) diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go index 80826b9d4..529ea24d6 100644 --- a/management/server/http/handlers/users/invites_handler_test.go +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -205,6 +205,14 @@ func TestCreateInvite(t *testing.T) { return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") }, }, + { + name: "local auth disabled", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + }, + }, { name: "invalid JSON", requestBody: `{invalid json}`, @@ -376,6 +384,15 @@ func TestAcceptInvite(t *testing.T) { return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") }, }, + { + name: "local auth disabled", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + }, + }, { name: "missing token", token: "", diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 9339c3541..1fd4c9bad 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee proxyController := integrations.NewController(store) userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index db7a91fa3..a27050a26 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct { Owner *OwnerConfig // SignKeyRefreshEnabled enables automatic key rotation for signing keys SignKeyRefreshEnabled bool + // LocalAuthDisabled disables the local (email/password) authentication connector. + // When true, users cannot authenticate via email/password, only via external identity providers. + // Existing local users are preserved and will be able to login again if re-enabled. + // Cannot be enabled if no external identity provider connectors are configured. + LocalAuthDisabled bool } // EmbeddedStorageConfig holds storage configuration for the embedded IdP. @@ -105,6 +110,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { Issuer: "NetBird", Theme: "light", }, + // Always enable password DB initially - we disable the local connector after startup if needed. + // This ensures Dex has at least one connector during initialization. EnablePasswordDB: true, StaticClients: []storage.Client{ { @@ -192,11 +199,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe return nil, err } + log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config) + provider, err := dex.NewProviderFromYAML(ctx, yamlConfig) if err != nil { return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err) } + // If local auth is disabled, validate that other connectors exist + if config.LocalAuthDisabled { + hasOthers, err := provider.HasNonLocalConnectors(ctx) + if err != nil { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("failed to check connectors: %w", err) + } + if !hasOthers { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured") + } + // Ensure local connector is removed (it might exist from a previous run) + if err := provider.DisableLocalAuth(ctx); err != nil { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("failed to disable local auth: %w", err) + } + log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used") + } + log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer) return &EmbeddedIdPManager{ @@ -281,6 +309,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]* return nil, fmt.Errorf("failed to list users: %w", err) } + log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users)) + indexedUsers := make(map[string][]*UserData) for _, user := range users { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{ @@ -290,11 +320,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]* }) } + log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID])) + return indexedUsers, nil } // CreateUser creates a new user in the embedded IdP. func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + if m.config.LocalAuthDisabled { + return nil, fmt.Errorf("local user creation is disabled") + } + if m.appMetrics != nil { m.appMetrics.IDPMetrics().CountCreateUser() } @@ -364,6 +400,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) ( // Unlike CreateUser which auto-generates a password, this method uses the provided password. // This is useful for instance setup where the user provides their own password. func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) { + if m.config.LocalAuthDisabled { + return nil, fmt.Errorf("local user creation is disabled") + } + if m.appMetrics != nil { m.appMetrics.IDPMetrics().CountCreateUser() } @@ -553,3 +593,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string { func (m *EmbeddedIdPManager) GetUserIDClaim() string { return defaultUserIDClaim } + +// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration. +func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool { + return m.config.LocalAuthDisabled +} + +// HasNonLocalConnectors checks if there are any identity provider connectors other than local. +func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) { + return m.provider.HasNonLocalConnectors(ctx) +} diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index d8d3009dd..4dda483fb 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) { }) } } + +func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) { + ctx := context.Background() + + t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + _, err = NewEmbeddedIdPManager(ctx, config, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no other identity providers configured") + }) + + t.Run("local auth enabled by default", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Verify local auth is enabled by default + assert.False(t, manager.IsLocalAuthDisabled()) + }) + + t.Run("start with local auth disabled when connector exists", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager with local auth enabled and add a connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + // Create a user + userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com") + require.NoError(t, err) + userID := userData.ID + + // Add an external connector (Google doesn't require OIDC discovery) + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + // Stop the first manager + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Now create a new manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Verify local auth is disabled via config + assert.True(t, manager2.IsLocalAuthDisabled()) + + // Verify the user still exists in storage (just can't login via local) + lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{}) + require.NoError(t, err) + assert.Equal(t, "preserved@example.com", lookedUp.Email) + }) + + t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager and add an external connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Create manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Try to create a user - should fail + _, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "local user creation is disabled") + }) + + t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager and add an external connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Create manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Try to create a user with password - should fail + _, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User") + require.Error(t, err) + assert.Contains(t, err.Error(), "local user creation is disabled") + }) +} diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 6a0509ebd..19e3abdc0 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -104,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) } func (m *DefaultManager) loadSetupRequired(ctx context.Context) error { + // Check if there are any accounts in the NetBird store + numAccounts, err := m.store.GetAccountsCounter(ctx) + if err != nil { + return err + } + hasAccounts := numAccounts > 0 + + // Check if there are any users in the embedded IdP (Dex) users, err := m.embeddedIdpManager.GetAllAccounts(ctx) if err != nil { return err } + hasLocalUsers := len(users) > 0 m.setupMu.Lock() - m.setupRequired = len(users) == 0 + m.setupRequired = !(hasAccounts || hasLocalUsers) m.setupMu.Unlock() return nil diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..74af0a3ef 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -24,19 +24,28 @@ type Manager interface { UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) } +// IdpConfig holds IdP-related configuration that is set at runtime +// and not stored in the database. +type IdpConfig struct { + EmbeddedIdpEnabled bool + LocalAuthDisabled bool +} + type managerImpl struct { store store.Store extraSettingsManager extra_settings.Manager userManager users.Manager permissionsManager permissions.Manager + idpConfig IdpConfig } -func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager { return &managerImpl{ store: store, extraSettingsManager: extraSettingsManager, userManager: userManager, permissionsManager: permissionsManager, + idpConfig: idpConfig, } } @@ -74,6 +83,10 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled } + // Fill in IdP-related runtime settings + settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled + settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled + return settings, nil } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 867e12bef..a94e01b78 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -55,6 +55,14 @@ type Settings struct { // AutoUpdateVersion client auto-update version AutoUpdateVersion string `gorm:"default:'disabled'"` + + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. + // This is a runtime-only field, not stored in the database. + EmbeddedIdpEnabled bool `gorm:"-"` + + // LocalAuthDisabled indicates if local (email/password) authentication is disabled. + // This is a runtime-only field, not stored in the database. + LocalAuthDisabled bool `gorm:"-"` } // Copy copies the Settings struct @@ -76,6 +84,8 @@ func (s *Settings) Copy() *Settings { DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, AutoUpdateVersion: s.AutoUpdateVersion, + EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, + LocalAuthDisabled: s.LocalAuthDisabled, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/user.go b/management/server/user.go index 51da7a633..48005f325 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -191,6 +191,10 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID // Unlike createNewIdpUser, this method fetches user data directly from the database // since the embedded IdP usage ensures the username and email are stored locally in the User table. func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) { + if IsLocalAuthDisabled(ctx, am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID) if err != nil { return nil, fmt.Errorf("failed to get inviter user: %w", err) @@ -1462,6 +1466,10 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } + if IsLocalAuthDisabled(ctx, am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + if err := validateUserInvite(invite); err != nil { return nil, err } @@ -1621,6 +1629,10 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } + if IsLocalAuthDisabled(ctx, am.idpManager) { + return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + if password == "" { return status.Errorf(status.InvalidArgument, "password is required") } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 26d2387d1..b9a8eae3a 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -294,6 +294,11 @@ components: type: boolean readOnly: true example: false + local_auth_disabled: + description: Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. + type: boolean + readOnly: true + example: false required: - peer_login_expiration_enabled - peer_login_expiration diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index e8c044b32..fd7c61917 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -415,6 +415,9 @@ type AccountSettings struct { // LazyConnectionEnabled Enables or disables experimental lazy connection LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` + // LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. + LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"` + // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` From 7b830d8f72b6fa997f03d99292de814faa33a562 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Sun, 1 Feb 2026 14:37:00 +0100 Subject: [PATCH 11/27] disable sync lim (#5233) --- management/internals/shared/grpc/server.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 219baaf6d..6757cca13 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -77,8 +77,9 @@ type Server struct { oAuthConfigProvider idp.OAuthConfigProvider - syncSem atomic.Int32 - syncLim int32 + syncSem atomic.Int32 + syncLimEnabled bool + syncLim int32 } // NewServer creates a new Management server @@ -108,6 +109,7 @@ func NewServer( blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" syncLim := int32(defaultSyncLim) + syncLimEnabled := true if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { syncLimParsed, err := strconv.Atoi(syncLimStr) if err != nil { @@ -115,6 +117,9 @@ func NewServer( } else { //nolint:gosec syncLim = int32(syncLimParsed) + if syncLim < 0 { + syncLimEnabled = false + } } } @@ -134,7 +139,8 @@ func NewServer( loginFilter: newLoginFilter(), - syncLim: syncLim, + syncLim: syncLim, + syncLimEnabled: syncLimEnabled, }, nil } @@ -212,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { - if s.syncSem.Load() >= s.syncLim { + if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim { return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") } s.syncSem.Add(1) From 893129334376ef0cf65198dfbf5f428290ef6244 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Sun, 1 Feb 2026 15:44:27 +0100 Subject: [PATCH 12/27] [management] run cancelPeerRoutinesWithoutLock in sync (#5234) --- management/internals/shared/grpc/server.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6757cca13..3704b3188 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -311,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) return err } @@ -319,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) return err } @@ -490,6 +490,10 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) +} + +func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) { err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) From b20d4849720754928ce08a88d8970877a5b0daa5 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 1 Feb 2026 16:06:36 +0100 Subject: [PATCH 13/27] [docs] Add selfhosting video (#5235) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f4c04641..bca81c20b 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,8 @@ https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 -### NetBird on Lawrence Systems (Video) -[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) +### Self-Host NetBird (Video) +[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ) ### Key features From 6fdc00ff4185849cf9d3e65d38971684a0ffe65e Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:30:02 +0100 Subject: [PATCH 14/27] [management] adding account id validation to accessible peers handler (#5246) --- management/server/http/handler.go | 5 +++-- .../http/handlers/peers/peers_handler.go | 19 ++++++++++++++----- .../http/handlers/peers/peers_handler_test.go | 11 +++++++++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 79431a0a3..17355d1d9 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,10 +9,11 @@ import ( "time" "github.com/gorilla/mux" - idpmanager "github.com/netbirdio/netbird/management/server/idp" "github.com/rs/cors" log "github.com/sirupsen/logrus" + idpmanager "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/modules/zones" @@ -137,7 +138,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks } accounts.AddEndpoints(accountManager, settingsManager, router) - peers.AddEndpoints(accountManager, router, networkMapController) + peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager) users.AddEndpoints(accountManager, router) users.AddInvitesEndpoints(accountManager, router) users.AddPublicInvitesEndpoints(accountManager, router) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 53d8ab055..783cfe11b 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -17,6 +17,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -26,11 +27,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { accountManager account.Manager + permissionsManager permissions.Manager networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { - peersHandler := NewHandler(accountManager, networkMapController) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) { + peersHandler := NewHandler(accountManager, networkMapController, permissionsManager) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler { return &Handler{ accountManager: accountManager, networkMapController: networkMapController, + permissionsManager: permissionsManager, } } @@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) + user, err := h.accountManager.GetUserByID(r.Context(), userID) if err != nil { util.WriteError(r.Context(), err, w) return } - user, err := h.accountManager.GetUserByID(r.Context(), userID) + err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false) + if err != nil { + util.WriteError(r.Context(), status.NewPermissionDeniedError(), w) + return + } + + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 869a39b5e..786c144fc 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -13,13 +13,15 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" - "go.uber.org/mock/gomock" + ugomock "go.uber.org/mock/gomock" "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" @@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, } - ctrl := gomock.NewController(t) + ctrl := ugomock.NewController(t) networkMapController := network_map.NewMockController(ctrl) networkMapController.EXPECT(). @@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { Return("domain"). AnyTimes() + ctrl2 := gomock.NewController(t) + permissionsManager := permissions.NewMockManager(ctrl2) + permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, }, networkMapController: networkMapController, + permissionsManager: permissionsManager, } } From d488f583115402c62d9974e787a5b23f1b73fc32 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:44:46 +0100 Subject: [PATCH 15/27] [management] fix set disconnected status for connected peer (#5247) --- management/internals/shared/grpc/server.go | 32 ++++++----- management/server/account.go | 16 +++++- management/server/account/manager.go | 2 +- management/server/account_test.go | 55 +++++++++++++++++++ management/server/mock_server/account_mock.go | 5 +- 5 files changed, 89 insertions(+), 21 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 3704b3188..befcd2adf 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -307,11 +307,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return mapError(ctx, err) } + streamStartTime := time.Now().UTC() + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) return err } @@ -319,7 +321,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) return err } @@ -336,7 +338,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { @@ -404,7 +406,7 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -416,11 +418,11 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -429,7 +431,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return srv.Context().Err() } } @@ -437,16 +439,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { key, err := s.secretsManager.GetWGKey() if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.Send(&proto.EncryptedMessage{ @@ -454,7 +456,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) @@ -486,15 +488,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even return nil } -func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) } -func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) { - err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) +func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { + err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime) if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } diff --git a/management/server/account.go b/management/server/account.go index 8f9dad031..4f53415f5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1684,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID return peer, netMap, postureChecks, dnsfwdPort, nil } -func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { - err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) +func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) + if err != nil { + log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) + return nil + } + + if peer.Status.LastSeen.After(streamStartTime) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", + peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) + return nil + } + + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 5e9bb42a2..eed7739da 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -115,7 +115,7 @@ type Manager interface { GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) - OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error + OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 86cc69e8b..f3d98916c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1961,6 +1961,61 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } } +func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err, "unable to get peer") + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := time.Now().UTC() + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.False(t, peer.Status.Connected, "peer should be disconnected") + }) + + t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, + "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + }) +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 026989898..a4754d180 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -221,9 +221,8 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { - // TODO implement me - panic("implement me") +func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + return nil } func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { From f7732557fa42622c85c96aa894a648ac7f5cb58f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 5 Feb 2026 01:07:27 +0800 Subject: [PATCH 16/27] [client] Add missing bsd flags in debug bundle (#5254) --- .../routemanager/systemops/routeflags_bsd.go | 40 ++++++++++++------- .../systemops/routeflags_freebsd.go | 38 +++++++++++------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/client/internal/routemanager/systemops/routeflags_bsd.go b/client/internal/routemanager/systemops/routeflags_bsd.go index ad32e5029..33280bfb3 100644 --- a/client/internal/routemanager/systemops/routeflags_bsd.go +++ b/client/internal/routemanager/systemops/routeflags_bsd.go @@ -4,16 +4,17 @@ package systemops import ( "strings" - "syscall" + + "golang.org/x/sys/unix" ) // filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { - if routeMessageFlags&syscall.RTF_UP == 0 { + if routeMessageFlags&unix.RTF_UP == 0 { return true } - if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 { + if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 { return true } @@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool { func formatBSDFlags(flags int) string { var flagStrs []string - if flags&syscall.RTF_UP != 0 { + if flags&unix.RTF_UP != 0 { flagStrs = append(flagStrs, "U") } - if flags&syscall.RTF_GATEWAY != 0 { + if flags&unix.RTF_GATEWAY != 0 { flagStrs = append(flagStrs, "G") } - if flags&syscall.RTF_HOST != 0 { + if flags&unix.RTF_HOST != 0 { flagStrs = append(flagStrs, "H") } - if flags&syscall.RTF_REJECT != 0 { + if flags&unix.RTF_REJECT != 0 { flagStrs = append(flagStrs, "R") } - if flags&syscall.RTF_DYNAMIC != 0 { + if flags&unix.RTF_DYNAMIC != 0 { flagStrs = append(flagStrs, "D") } - if flags&syscall.RTF_MODIFIED != 0 { + if flags&unix.RTF_MODIFIED != 0 { flagStrs = append(flagStrs, "M") } - if flags&syscall.RTF_STATIC != 0 { + if flags&unix.RTF_STATIC != 0 { flagStrs = append(flagStrs, "S") } - if flags&syscall.RTF_LLINFO != 0 { + if flags&unix.RTF_LLINFO != 0 { flagStrs = append(flagStrs, "L") } - if flags&syscall.RTF_LOCAL != 0 { + if flags&unix.RTF_LOCAL != 0 { flagStrs = append(flagStrs, "l") } - if flags&syscall.RTF_BLACKHOLE != 0 { + if flags&unix.RTF_BLACKHOLE != 0 { flagStrs = append(flagStrs, "B") } - if flags&syscall.RTF_CLONING != 0 { + if flags&unix.RTF_CLONING != 0 { flagStrs = append(flagStrs, "C") } - if flags&syscall.RTF_WASCLONED != 0 { + if flags&unix.RTF_WASCLONED != 0 { flagStrs = append(flagStrs, "W") } + if flags&unix.RTF_PROTO1 != 0 { + flagStrs = append(flagStrs, "1") + } + if flags&unix.RTF_PROTO2 != 0 { + flagStrs = append(flagStrs, "2") + } + if flags&unix.RTF_PROTO3 != 0 { + flagStrs = append(flagStrs, "3") + } if len(flagStrs) == 0 { return "-" diff --git a/client/internal/routemanager/systemops/routeflags_freebsd.go b/client/internal/routemanager/systemops/routeflags_freebsd.go index 2338fe5d8..a8c82b3ed 100644 --- a/client/internal/routemanager/systemops/routeflags_freebsd.go +++ b/client/internal/routemanager/systemops/routeflags_freebsd.go @@ -4,17 +4,18 @@ package systemops import ( "strings" - "syscall" + + "golang.org/x/sys/unix" ) // filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { - if routeMessageFlags&syscall.RTF_UP == 0 { + if routeMessageFlags&unix.RTF_UP == 0 { return true } - // NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 - if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 { + // NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0 + if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 { return true } @@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool { func formatBSDFlags(flags int) string { var flagStrs []string - if flags&syscall.RTF_UP != 0 { + if flags&unix.RTF_UP != 0 { flagStrs = append(flagStrs, "U") } - if flags&syscall.RTF_GATEWAY != 0 { + if flags&unix.RTF_GATEWAY != 0 { flagStrs = append(flagStrs, "G") } - if flags&syscall.RTF_HOST != 0 { + if flags&unix.RTF_HOST != 0 { flagStrs = append(flagStrs, "H") } - if flags&syscall.RTF_REJECT != 0 { + if flags&unix.RTF_REJECT != 0 { flagStrs = append(flagStrs, "R") } - if flags&syscall.RTF_DYNAMIC != 0 { + if flags&unix.RTF_DYNAMIC != 0 { flagStrs = append(flagStrs, "D") } - if flags&syscall.RTF_MODIFIED != 0 { + if flags&unix.RTF_MODIFIED != 0 { flagStrs = append(flagStrs, "M") } - if flags&syscall.RTF_STATIC != 0 { + if flags&unix.RTF_STATIC != 0 { flagStrs = append(flagStrs, "S") } - if flags&syscall.RTF_LLINFO != 0 { + if flags&unix.RTF_LLINFO != 0 { flagStrs = append(flagStrs, "L") } - if flags&syscall.RTF_LOCAL != 0 { + if flags&unix.RTF_LOCAL != 0 { flagStrs = append(flagStrs, "l") } - if flags&syscall.RTF_BLACKHOLE != 0 { + if flags&unix.RTF_BLACKHOLE != 0 { flagStrs = append(flagStrs, "B") } // Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0 + if flags&unix.RTF_PROTO1 != 0 { + flagStrs = append(flagStrs, "1") + } + if flags&unix.RTF_PROTO2 != 0 { + flagStrs = append(flagStrs, "2") + } + if flags&unix.RTF_PROTO3 != 0 { + flagStrs = append(flagStrs, "3") + } if len(flagStrs) == 0 { return "-" From 194a986926cfcce284e77ee6eca84732d8976ace Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 4 Feb 2026 22:22:37 +0100 Subject: [PATCH 17/27] Cache the result of wgInterface.ToInterface() using sync.Once (#5256) Avoid repeated conversions during route setup. The toInterface helper ensures the conversion happens only once regardless of how many routes are added or removed. --- client/internal/routemanager/manager.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 2baa0e668..077b9521b 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { } func (m *DefaultManager) setupRefCounters(useNoop bool) { + var once sync.Once + var wgIface *net.Interface + toInterface := func() *net.Interface { + once.Do(func() { + wgIface = m.wgInterface.ToInterface() + }) + return wgIface + } + m.routeRefCounter = refcounter.New( func(prefix netip.Prefix, _ struct{}) (struct{}, error) { - return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) + return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface()) }, func(prefix netip.Prefix, _ struct{}) error { - return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) + return m.sysOps.RemoveVPNRoute(prefix, toInterface()) }, ) From d2f9653cea8c3b7d119c657366e967eb5869dd07 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 5 Feb 2026 12:06:28 +0100 Subject: [PATCH 18/27] Fix nil pointer panic in ICE agent during sleep/wake cycles (#5261) Add defensive nil checks in ThreadSafeAgent.Close() to prevent panic when agent field is nil. This can occur during Windows suspend/resume when network interfaces are disrupted or the pion/ice library returns nil without error. Also capture agent pointer in local variable before goroutine execution to prevent race conditions. Fixes service crashes on laptop wake-up. --- client/internal/peer/ice/agent.go | 51 +++++++++++++++++++----------- client/internal/peer/worker_ice.go | 6 ++-- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 79f68d279..c74b46d10 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -2,6 +2,7 @@ package ice import ( "context" + "fmt" "sync" "time" @@ -32,24 +33,6 @@ type ThreadSafeAgent struct { once sync.Once } -func (a *ThreadSafeAgent) Close() error { - var err error - a.once.Do(func() { - done := make(chan error, 1) - go func() { - done <- a.Agent.Close() - }() - - select { - case err = <-done: - case <-time.After(iceAgentCloseTimeout): - log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) - err = nil - } - }) - return err -} - func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() @@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c return nil, err } + if agent == nil { + return nil, fmt.Errorf("ice.NewAgent returned nil agent without error") + } + return &ThreadSafeAgent{Agent: agent}, nil } +func (a *ThreadSafeAgent) Close() error { + var err error + a.once.Do(func() { + // Defensive check to prevent nil pointer dereference + // This can happen during sleep/wake transitions or memory corruption scenarios + // github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?) + // [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c] + agent := a.Agent + if agent == nil { + log.Warnf("ICE agent is nil during close, skipping") + return + } + + done := make(chan error, 1) + go func() { + done <- agent.Close() + }() + + select { + case err = <-done: + case <-time.After(iceAgentCloseTimeout): + log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) + err = nil + } + }) + return err +} + func GenerateICECredentials() (string, string, error) { ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) if err != nil { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index b6b9d2cf4..464f57bff 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } w.log.Debugf("agent already exists, recreate the connection") w.agentDialerCancel() - if err := w.agent.Close(); err != nil { - w.log.Warnf("failed to close ICE agent: %s", err) + if w.agent != nil { + if err := w.agent.Close(); err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } } sessionID, err := NewICESessionID() From 1b96648d4d190ec2d57b3a195004445dd5e10b16 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:34:35 +0800 Subject: [PATCH 19/27] [client] Always log dns forwader responses (#5262) --- client/internal/dnsfwd/forwarder.go | 125 ++++++++++------------- client/internal/dnsfwd/forwarder_test.go | 90 +++++++--------- 2 files changed, 93 insertions(+), 122 deletions(-) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 1230a4e46..5c7cb31fc 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error { return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg { +func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) { if len(query.Question) == 0 { - return nil + return } question := query.Question[0] - logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s", - question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) + qname := strings.ToLower(question.Name) - domain := strings.ToLower(question.Name) + logger.Tracef("question: domain=%s type=%s class=%s", + qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) resp := query.SetReply(query) network := resutil.NetworkForQtype(question.Qtype) if network == "" { resp.Rcode = dns.RcodeNotImplemented - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) - // query doesn't match any configured domain + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, ".")) if mostSpecificResId == "" { resp.Rcode = dns.RcodeRefused - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype) + result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype) if result.Err != nil { - f.handleDNSError(ctx, logger, w, question, resp, domain, result) - return nil + f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime) + return } f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries) - resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...) - f.cache.set(domain, question.Qtype, result.IPs) + resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...) + f.cache.set(qname, question.Qtype, result.IPs) - return resp + f.writeResponse(logger, w, resp, qname, startTime) +} + +func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) { + if err := w.WriteMsg(resp); err != nil { + logger.Errorf("failed to write DNS response: %v", err) + return + } + + logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) +} + +// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. +type udpResponseWriter struct { + dns.ResponseWriter + query *dns.Msg +} + +func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { + opt := u.query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + maxSize = int(opt.UDPSize()) + } + + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } + + return u.ResponseWriter.WriteMsg(resp) } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { @@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - opt := query.IsEdns0() - maxSize := dns.MinMsgSize - if opt != nil { - // client advertised a larger EDNS0 buffer - maxSize = int(opt.UDPSize()) - } - - // if our response is too big, truncate and set the TC bit - if resp.Len() > maxSize { - resp.Truncate(maxSize) - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { @@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, w, query, startTime) } func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { @@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError( resp *dns.Msg, domain string, result resutil.LookupResult, + startTime time.Time, ) { qType := question.Qtype qTypeName := dns.TypeToString[qType] @@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError( // NotFound: cache negative result and respond if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess { f.cache.set(domain, question.Qtype, nil) - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError( logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...) resp.Rcode = dns.RcodeSuccess - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write cached DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError( verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType) if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess { resp.Rcode = verifyResult.Rcode - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } } @@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError( // No cache or verification failed. Log with or without the server field for more context. var dnsErr *net.DNSError if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" { - logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) + logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) } else { logger.Warnf(errResolveFailed, domain, result.Err) } - // Write final failure response. - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) } // getMatchingEntries retrieves the resource IDs for a given domain. diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 6416c2f21..7325ef8a7 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") @@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockFirewall.AssertExpectations(t) mockResolver.AssertExpectations(t) } else { - if resp != nil { - assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, - "Unauthorized domain should not return successful answers") - } + require.NotNil(t, resp, "Expected response") + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") mockFirewall.AssertNotCalled(t, "UpdateSet") mockResolver.AssertNotCalled(t, "LookupNetIP") } @@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now()) // Verify response + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.NotEmpty(t, resp.Answer) - } else if resp != nil { + } else { + require.NotNil(t, resp, "Expected response") assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, "Unauthorized domain should be refused or have no answers") } @@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { query.SetQuestion("example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Verify response contains all IPs + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.Len(t, resp.Answer, 3, "Should have 3 answer records") @@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { }, } - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Check the response written to the writer require.NotNil(t, writtenResp, "Expected response to be written") @@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Second query: serve from cache after upstream failure q2 := &dns.Msg{} q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp, "expected response to be written") - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(mixedQuery+".", dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q2 := &dns.Msg{} q2.SetQuestion("EXAMPLE.COM", dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp) - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2) + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { query.SetQuestion("smtp.mail.example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) assert.Equal(t, dns.RcodeSuccess, resp.Rcode) @@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { query := &dns.Msg{} query.SetQuestion(dns.Fqdn("example.com"), tt.queryType) - var writtenResp *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writtenResp = m - return nil - }, - } + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) - - // If a response was returned, it means it should be written (happens in wrapper functions) - if resp != nil && writtenResp == nil { - writtenResp = resp - } - - require.NotNil(t, writtenResp, "Expected response to be written") - assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + resp := mockWriter.GetLastResponse() + require.NotNil(t, resp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description) if tt.expectNoAnswer { - assert.Empty(t, writtenResp.Answer, "Response should have no answer records") + assert.Empty(t, resp.Answer, "Response should have no answer records") } mockResolver.AssertExpectations(t) @@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) { query := &dns.Msg{} // Don't set any question - writeCalled := false - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writeCalled = true - return nil - }, - } - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - assert.Nil(t, resp, "Should return nil for empty query") - assert.False(t, writeCalled, "Should not write response for empty query") + assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query") } From 0119f3e9f4f87e49af92c3ee5a4a79c7eebe9a1d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:03:01 +0800 Subject: [PATCH 20/27] [client] Fix netstack detection and add wireguard port option (#5251) - Add WireguardPort option to embed.Options for custom port configuration - Fix KernelInterface detection to account for netstack mode - Skip SSH config updates when running in netstack mode - Skip interface removal wait when running in netstack mode - Use BindListener for netstack to avoid port conflicts on same host --- client/embed/embed.go | 3 +++ client/iface/iface.go | 5 +++++ client/internal/connect.go | 3 ++- client/internal/engine.go | 2 +- client/internal/engine_ssh.go | 9 +++++++++ client/internal/lazyconn/activity/manager.go | 8 +++++--- 6 files changed, 25 insertions(+), 5 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index e73f37e35..2ad025ff0 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -71,6 +71,8 @@ type Options struct { DisableClientRoutes bool // BlockInbound blocks all inbound connections from peers BlockInbound bool + // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. + WireguardPort *int } // validateCredentials checks that exactly one credential type is provided @@ -140,6 +142,7 @@ func New(opts Options) (*Client, error) { DisableServerRoutes: &t, DisableClientRoutes: &opts.DisableClientRoutes, BlockInbound: &opts.BlockInbound, + WireguardPort: opts.WireguardPort, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) diff --git a/client/iface/iface.go b/client/iface/iface.go index e5623c979..9b331d68c 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" @@ -228,6 +229,10 @@ func (w *WGIface) Close() error { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) } + if nbnetstack.IsEnabled() { + return errors.FormatErrorOrNil(result) + } + if err := w.waitUntilRemoved(); err != nil { log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) if err := w.Destroy(); err != nil { diff --git a/client/internal/connect.go b/client/internal/connect.go index 7fc3c9a96..17fc20c42 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: device.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(), FQDN: loginResp.GetPeerConfig().GetFqdn(), } c.statusRecorder.UpdateLocalPeerState(localPeerState) diff --git a/client/internal/engine.go b/client/internal/engine.go index 63ba1c9f2..597ac7c2d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1017,7 +1017,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { state := e.statusRecorder.GetLocalPeerState() state.IP = e.wgInterface.Address().String() state.PubKey = e.config.WgPrivateKey.PublicKey().String() - state.KernelInterface = device.WireGuardModuleIsLoaded() + state.KernelInterface = !e.wgInterface.IsUserspaceBind() state.FQDN = conf.GetFqdn() e.statusRecorder.UpdateLocalPeerState(state) diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index a8c05fe0a..1419bc262 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/netstack" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" sshauth "github.com/netbirdio/netbird/client/ssh/auth" sshconfig "github.com/netbirdio/netbird/client/ssh/config" @@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { // updateSSHClientConfig updates the SSH client configuration with peer information func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error { + if netstack.IsEnabled() { + return nil + } + peerInfo := e.extractPeerSSHInfo(remotePeers) if len(peerInfo) == 0 { log.Debug("no SSH-enabled peers found, skipping SSH config update") @@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { // cleanupSSHConfig removes NetBird SSH client configuration on shutdown func (e *Engine) cleanupSSHConfig() { + if netstack.IsEnabled() { + return + } + configMgr := sshconfig.New() if err := configMgr.RemoveSSHClientConfig(); err != nil { diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index db283ec9a..1c11378c8 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" @@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) return NewUDPListener(m.wgIface, peerCfg) } - // BindListener is only used on Windows and JS platforms: + // BindListener is used on Windows, JS, and netstack platforms: // - JS: Cannot listen to UDP sockets // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default // gateway points to, preventing them from reaching the loopback interface. - // BindListener bypasses this by passing data directly through the bind. - if runtime.GOOS != "windows" && runtime.GOOS != "js" { + // - Netstack: Allows multiple instances on the same host without port conflicts. + // BindListener bypasses these issues by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() { return NewUDPListener(m.wgIface, peerCfg) } From c3f176f34835151da3a3bf493960b3e8a7857421 Mon Sep 17 00:00:00 2001 From: eyJhb Date: Fri, 6 Feb 2026 11:23:36 +0100 Subject: [PATCH 21/27] [client] Fix wrong URL being logged for DefaultAdminURL (#5252) - DefaultManagementURL was being logged instead of DefaultAdminURL --- client/internal/profilemanager/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index f2fda84e0..8f3ff8b11 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { } if config.AdminURL == nil { - log.Infof("using default Admin URL %s", DefaultManagementURL) + log.Infof("using default Admin URL %s", DefaultAdminURL) config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) if err != nil { return false, err From af8f730bdac9109b02f3432efd6b6d24a251cea9 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:00:43 +0100 Subject: [PATCH 22/27] [management] check stream start time for connecting peer (#5267) --- management/internals/shared/grpc/server.go | 10 +++--- management/server/account.go | 6 ++-- management/server/account/manager.go | 4 +-- management/server/account_test.go | 33 +++++++++++++++---- management/server/mock_server/account_mock.go | 12 +++---- management/server/peer.go | 20 ++++++++--- 6 files changed, 58 insertions(+), 27 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index befcd2adf..98c68ebda 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -300,20 +300,18 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S metahash := metaHash(peerMeta, realIP.String()) s.loginFilter.addLogin(peerKey.String(), metahash) - peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return mapError(ctx, err) } - streamStartTime := time.Now().UTC() - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -321,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -338,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { diff --git a/management/server/account.go b/management/server/account.go index 4f53415f5..a9f59773a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1670,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -1697,7 +1697,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account return nil } - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index eed7739da..1d25b0af7 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -58,7 +58,7 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error @@ -114,7 +114,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index f3d98916c..443e6344e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1979,7 +1979,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { require.NoError(t, err, "unable to add peer") t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) @@ -1997,7 +1997,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { }) t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) @@ -2014,6 +2014,27 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { require.True(t, peer.Status.Connected, "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") }) + + t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + node2SyncTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + require.NoError(t, err, "node 2 should connect peer") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + + node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + require.NoError(t, err, "stale connect should not return error") + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should still be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), + "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + }) } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { @@ -2038,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} @@ -3231,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC()) assert.NoError(b, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a4754d180..8471d0a94 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -37,8 +37,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -214,9 +214,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime) } return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } @@ -322,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) + return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index ab72d3051..a4bdc784d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { +// syncTime is used as the LastSeen timestamp and for stale request detection +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { var peer *nbpeer.Peer var settings *types.Settings var expired bool var err error + var skipped bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) @@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID) + if connected && !syncTime.After(peer.Status.LastSeen) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", + peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) + skipped = true + return nil + } + + expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) return err }) + if skipped { + return nil + } if err != nil { return err } @@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus - newStatus.LastSeen = time.Now().UTC() + newStatus.LastSeen = syncTime newStatus.Connected = connected // whenever peer got connected that means that it logged in successfully if newStatus.Connected { From 3be16d19a0d5167d2ceeaac321b987b17f3120f4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 6 Feb 2026 19:47:38 +0100 Subject: [PATCH 23/27] [management] Feature/grpc debounce msgtype (#5239) * Add gRPC update debouncing mechanism Implements backpressure handling for peer network map updates to efficiently handle rapid changes. First update is sent immediately, subsequent rapid updates are coalesced, ensuring only the latest update is sent after a 1-second quiet period. * Enhance unit test to verify peer count synchronization with debouncing and timeout handling * Debounce based on type * Refactor test to validate timer restart after pending update dispatch * Simplify timer reset for Go 1.23+ automatic channel draining Remove manual channel drain in resetTimer() since Go 1.23+ automatically drains the timer channel when Stop() returns false, making the select-case pattern unnecessary. --- .../network_map/controller/controller.go | 11 +- .../update_channel/updatechannel_test.go | 22 +- .../controllers/network_map/update_message.go | 15 +- management/internals/shared/grpc/server.go | 33 +- management/internals/shared/grpc/token_mgr.go | 10 +- .../internals/shared/grpc/update_debouncer.go | 103 +++ .../shared/grpc/update_debouncer_test.go | 587 ++++++++++++++++++ management/server/management_test.go | 59 +- 8 files changed, 818 insertions(+), 22 deletions(-) create mode 100644 management/internals/shared/grpc/update_debouncer.go create mode 100644 management/internals/shared/grpc/update_debouncer_test.go diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 5ae64e9f1..3e28e1380 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -247,7 +247,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) c.metrics.CountToSyncResponseDuration(time.Since(start)) - c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) }(peer) } @@ -370,7 +373,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) return nil } @@ -778,6 +784,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI }, }, }, + MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) diff --git a/management/internals/controllers/network_map/update_channel/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go index afc1e2c32..c73baf81f 100644 --- a/management/internals/controllers/network_map/update_channel/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 0, + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 0, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") @@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 10, + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 10, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } peersUpdater.SendUpdate(context.Background(), peer, update2) timeout := time.After(5 * time.Second) diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go index 33643bcbd..0ffddf8b2 100644 --- a/management/internals/controllers/network_map/update_message.go +++ b/management/internals/controllers/network_map/update_message.go @@ -4,6 +4,19 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// MessageType indicates the type of update message for debouncing strategy +type MessageType int + +const ( + // MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall) + // These updates can be safely debounced - only the latest state matters + MessageTypeNetworkMap MessageType = iota + // MessageTypeControlConfig represents control/config updates (tokens, peer expiration) + // These updates should not be dropped as they contain time-sensitive information + MessageTypeControlConfig +) + type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + MessageType MessageType } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 98c68ebda..ff9d7ea05 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -404,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. +// It implements a backpressure mechanism that sends the first update immediately, +// then debounces subsequent rapid updates, ensuring only the latest update is sent +// after a quiet period. func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) + + // Create a debouncer for this peer connection + debouncer := NewUpdateDebouncer(1000 * time.Millisecond) + defer debouncer.Stop() + for { select { // condition when there are some updates + // todo set the updates channel size to 1 case update, open := <-updates: if s.appMetrics != nil { s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) @@ -419,10 +428,28 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } + log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { - log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) - return err + if debouncer.ProcessUpdate(update) { + // Send immediately (first update or after quiet period) + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } + } + + // Timer expired - quiet period reached, send pending updates if any + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + continue + } + log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String()) + for _, pendingUpdate := range pendingUpdates { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } } // condition when client <-> server connection has been terminated diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index ccb32202f..65e58ad41 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/internals/shared/grpc/update_debouncer.go b/management/internals/shared/grpc/update_debouncer.go new file mode 100644 index 000000000..8af9c2656 --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer.go @@ -0,0 +1,103 @@ +package grpc + +import ( + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" +) + +// UpdateDebouncer implements a backpressure mechanism that: +// - Sends the first update immediately +// - Coalesces rapid subsequent network map updates (only latest matters) +// - Queues control/config updates (all must be delivered) +// - Preserves the order of messages (important for control configs between network maps) +// - Ensures pending updates are sent after a quiet period +type UpdateDebouncer struct { + debounceInterval time.Duration + timer *time.Timer + pendingUpdates []*network_map.UpdateMessage // Queue that preserves order + timerC <-chan time.Time +} + +// NewUpdateDebouncer creates a new debouncer with the specified interval +func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer { + return &UpdateDebouncer{ + debounceInterval: interval, + } +} + +// ProcessUpdate handles an incoming update and returns whether it should be sent immediately +func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool { + if d.timer == nil { + // No active debounce timer, signal to send immediately + // and start the debounce period + d.startTimer() + return true + } + + // Already in debounce period, accumulate this update preserving order + // Check if we should coalesce with the last pending update + if len(d.pendingUpdates) > 0 && + update.MessageType == network_map.MessageTypeNetworkMap && + d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap { + // Replace the last network map with this one (coalesce consecutive network maps) + d.pendingUpdates[len(d.pendingUpdates)-1] = update + } else { + // Append to the queue (preserves order for control configs and non-consecutive network maps) + d.pendingUpdates = append(d.pendingUpdates, update) + } + d.resetTimer() + return false +} + +// TimerChannel returns the timer channel for select statements +func (d *UpdateDebouncer) TimerChannel() <-chan time.Time { + if d.timer == nil { + return nil + } + return d.timerC +} + +// GetPendingUpdates returns and clears all pending updates after timer expiration. +// Updates are returned in the order they were received, with consecutive network maps +// already coalesced to only the latest one. +// If there were pending updates, it restarts the timer to continue debouncing. +// If there were no pending updates, it clears the timer (true quiet period). +func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage { + updates := d.pendingUpdates + d.pendingUpdates = nil + + if len(updates) > 0 { + // There were pending updates, so updates are still coming rapidly + // Restart the timer to continue debouncing mode + if d.timer != nil { + d.timer.Reset(d.debounceInterval) + } + } else { + // No pending updates means true quiet period - return to immediate mode + d.timer = nil + d.timerC = nil + } + + return updates +} + +// Stop stops the debouncer and cleans up resources +func (d *UpdateDebouncer) Stop() { + if d.timer != nil { + d.timer.Stop() + d.timer = nil + d.timerC = nil + } + d.pendingUpdates = nil +} + +func (d *UpdateDebouncer) startTimer() { + d.timer = time.NewTimer(d.debounceInterval) + d.timerC = d.timer.C +} + +func (d *UpdateDebouncer) resetTimer() { + d.timer.Stop() + d.timer.Reset(d.debounceInterval) +} diff --git a/management/internals/shared/grpc/update_debouncer_test.go b/management/internals/shared/grpc/update_debouncer_test.go new file mode 100644 index 000000000..075994a2d --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer_test.go @@ -0,0 +1,587 @@ +package grpc + +import ( + "testing" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + shouldSend := debouncer.ProcessUpdate(update) + + if !shouldSend { + t.Error("First update should be sent immediately") + } + + if debouncer.TimerChannel() == nil { + t.Error("Timer should be started after first update") + } +} + +func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update should be sent immediately + if !debouncer.ProcessUpdate(update1) { + t.Error("First update should be sent immediately") + } + + // Rapid subsequent updates should be coalesced + if debouncer.ProcessUpdate(update2) { + t.Error("Second rapid update should not be sent immediately") + } + + if debouncer.ProcessUpdate(update3) { + t.Error("Third rapid update should not be sent immediately") + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update3 { + t.Error("Should get the last update (update3)") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + debouncer.ProcessUpdate(update1) + + // Send second update within debounce period + debouncer.ProcessUpdate(update2) + + // Wait for timer + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update2 { + t.Error("Should get the last update") + } + if pendingUpdates[0] == update1 { + t.Error("Should not get the first update") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + debouncer.ProcessUpdate(update1) + + // Wait a bit, but not the full debounce period + time.Sleep(30 * time.Millisecond) + + // Send second update - should reset timer + debouncer.ProcessUpdate(update2) + + // Wait a bit more + time.Sleep(30 * time.Millisecond) + + // Send third update - should reset timer again + debouncer.ProcessUpdate(update3) + + // Now wait for the timer (should fire after last update's reset) + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update3 { + t.Error("Should get the last update (update3)") + } + // Timer should be restarted since there was a pending update + if debouncer.TimerChannel() == nil { + t.Error("Timer should be restarted after sending pending update") + } + case <-time.After(150 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update sent immediately + debouncer.ProcessUpdate(update1) + + // Second update coalesced + debouncer.ProcessUpdate(update2) + + // Wait for timer to expire + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) == 0 { + t.Fatal("Should have pending update") + } + + // After sending pending update, timer is restarted, so next update is NOT immediate + if debouncer.ProcessUpdate(update3) { + t.Error("Update after debounced send should not be sent immediately (timer restarted)") + } + + // Wait for the restarted timer and verify update3 is pending + select { + case <-debouncer.TimerChannel(): + finalUpdates := debouncer.GetPendingUpdates() + if len(finalUpdates) != 1 || finalUpdates[0] != update3 { + t.Error("Should get update3 as pending") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired for restarted timer") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_StopCleansUp(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send update to start timer + debouncer.ProcessUpdate(update) + + // Stop should clean up + debouncer.Stop() + + // Multiple stops should be safe + debouncer.Stop() +} + +func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate high-frequency updates + var lastUpdate *network_map.UpdateMessage + sentImmediately := 0 + for i := 0; i < 100; i++ { + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + lastUpdate = update + if debouncer.ProcessUpdate(update) { + sentImmediately++ + } + time.Sleep(1 * time.Millisecond) // Very rapid updates + } + + // Only first update should be sent immediately + if sentImmediately != 1 { + t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != lastUpdate { + t.Error("Should get the very last update") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + if !debouncer.ProcessUpdate(update) { + t.Error("First update should be sent immediately") + } + + // Wait for timer to expire with no additional updates (true quiet period) + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 0 { + t.Error("Should have no pending updates") + } + // After true quiet period, timer should be cleared + if debouncer.TimerChannel() != nil { + t.Error("Timer should be cleared after quiet period") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + updates := make([]*network_map.UpdateMessage, 5) + for i := range updates { + updates[i] = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + } + + // First update sent immediately + debouncer.ProcessUpdate(updates[0]) + + // Send updates 1, 2, 3, 4 rapidly - only last one should remain pending + debouncer.ProcessUpdate(updates[1]) + debouncer.ProcessUpdate(updates[2]) + debouncer.ProcessUpdate(updates[3]) + debouncer.ProcessUpdate(updates[4]) + + // Wait for debounce + <-debouncer.TimerChannel() + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0].Update.NetworkMap.Serial != 4 { + t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial) + } +} + +func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update sent immediately + if !debouncer.ProcessUpdate(update1) { + t.Error("First update should be sent immediately") + } + + // Wait for timer without sending any more updates (true quiet period) + <-debouncer.TimerChannel() + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) != 0 { + t.Error("Should have no pending updates during quiet period") + } + + // After true quiet period, next update should be sent immediately + if !debouncer.ProcessUpdate(update2) { + t.Error("Update after true quiet period should be sent immediately") + } +} + +func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate continuous high-frequency updates + for i := 0; i < 10; i++ { + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + + if i == 0 { + // First one sent immediately + if !debouncer.ProcessUpdate(update) { + t.Error("First update should be sent immediately") + } + } else { + // All others should be coalesced (not sent immediately) + if debouncer.ProcessUpdate(update) { + t.Errorf("Update %d should not be sent immediately", i) + } + } + + // Wait a bit but send next update before debounce expires + time.Sleep(20 * time.Millisecond) + } + + // Now wait for final debounce + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + t.Fatal("Should have the last update pending") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 9 { + t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + netmapUpdate := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}}, + MessageType: network_map.MessageTypeNetworkMap, + } + tokenUpdate1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + tokenUpdate2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + + // First update sent immediately + debouncer.ProcessUpdate(netmapUpdate) + + // Send multiple control config updates - they should all be queued + debouncer.ProcessUpdate(tokenUpdate1) + debouncer.ProcessUpdate(tokenUpdate2) + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get both control config updates + if len(pendingUpdates) != 2 { + t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates)) + } + // Control configs should come first + if pendingUpdates[0] != tokenUpdate1 { + t.Error("First pending update should be tokenUpdate1") + } + if pendingUpdates[1] != tokenUpdate2 { + t.Error("Second pending update should be tokenUpdate2") + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + netmapUpdate1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}}, + MessageType: network_map.MessageTypeNetworkMap, + } + netmapUpdate2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}}, + MessageType: network_map.MessageTypeNetworkMap, + } + tokenUpdate := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + + // First update sent immediately + debouncer.ProcessUpdate(netmapUpdate1) + + // Send token update and network map update + debouncer.ProcessUpdate(tokenUpdate) + debouncer.ProcessUpdate(netmapUpdate2) + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get 2 updates in order: token, then network map + if len(pendingUpdates) != 2 { + t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates)) + } + // Token update should come first (preserves order) + if pendingUpdates[0] != tokenUpdate { + t.Error("First pending update should be tokenUpdate") + } + // Network map update should come second + if pendingUpdates[1] != netmapUpdate2 { + t.Error("Second pending update should be netmapUpdate2") + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_OrderPreservation(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate: 50 network maps -> 1 control config -> 50 network maps + // Expected result: 3 messages (netmap, controlConfig, netmap) + + // Send first network map immediately + firstNetmap := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}}, + MessageType: network_map.MessageTypeNetworkMap, + } + if !debouncer.ProcessUpdate(firstNetmap) { + t.Error("First update should be sent immediately") + } + + // Send 49 more network maps (will be coalesced to last one) + var lastNetmapBatch1 *network_map.UpdateMessage + for i := 1; i < 50; i++ { + lastNetmapBatch1 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch1) + } + + // Send 1 control config + controlConfig := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + debouncer.ProcessUpdate(controlConfig) + + // Send 50 more network maps (will be coalesced to last one) + var lastNetmapBatch2 *network_map.UpdateMessage + for i := 50; i < 100; i++ { + lastNetmapBatch2 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch2) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get exactly 3 updates: netmap, controlConfig, netmap + if len(pendingUpdates) != 3 { + t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates)) + } + // First should be the last netmap from batch 1 + if pendingUpdates[0] != lastNetmapBatch1 { + t.Error("First pending update should be last netmap from batch 1") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 49 { + t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + // Second should be the control config + if pendingUpdates[1] != controlConfig { + t.Error("Second pending update should be control config") + } + // Third should be the last netmap from batch 2 + if pendingUpdates[2] != lastNetmapBatch2 { + t.Error("Third pending update should be last netmap from batch 2") + } + if pendingUpdates[2].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} diff --git a/management/server/management_test.go b/management/server/management_test.go index 0864baadf..de02855bf 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) { initialPeers := 10 additionalPeers := 10 + expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself var peers []wgtypes.Key for i := 0; i < initialPeers; i++ { @@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) { peers = append(peers, key) } + // Track the maximum peer count each peer has seen + type peerState struct { + mu sync.Mutex + maxPeerCount int + done bool + } + peerStates := make(map[string]*peerState) + for _, pk := range peers { + peerStates[pk.PublicKey().String()] = &peerState{} + } + var wg sync.WaitGroup - wg.Add(initialPeers + initialPeers*additionalPeers) + wg.Add(initialPeers) // One completion per initial peer var syncClients []mgmtProto.ManagementService_SyncClient for _, pk := range peers { @@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) { syncClients = append(syncClients, s) go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) { + pubKey := pk.PublicKey().String() + state := peerStates[pubKey] + for { encMsg := &mgmtProto.EncryptedMessage{} err := syncStream.RecvMsg(encMsg) @@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) { } decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk) if decErr != nil { - t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr) + t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr) return } resp := &mgmtProto.SyncResponse{} umErr := pb.Unmarshal(decryptedBytes, resp) if umErr != nil { - t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr) + t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr) return } - // We only count if there's a new peer update - if len(resp.GetRemotePeers()) > 0 { + + // Track the maximum peer count seen (due to debouncing, updates are coalesced) + peerCount := len(resp.GetRemotePeers()) + state.mu.Lock() + if peerCount > state.maxPeerCount { + state.maxPeerCount = peerCount + } + // Signal completion when this peer has seen all expected peers + if !state.done && state.maxPeerCount >= expectedPeerCount { + state.done = true wg.Done() } + state.mu.Unlock() } }(pk, s) } @@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) { time.Sleep(time.Duration(n) * time.Millisecond) } - wg.Wait() + // Wait for debouncer to flush final updates (debounce interval is 1000ms) + time.Sleep(1500 * time.Millisecond) + + // Wait with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success - all peers received expected peer count + case <-time.After(5 * time.Second): + // Timeout - report which peers didn't receive all updates + t.Error("Timeout waiting for all peers to receive updates") + for pubKey, state := range peerStates { + state.mu.Lock() + if state.maxPeerCount < expectedPeerCount { + t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount) + } + state.mu.Unlock() + } + } for _, sc := range syncClients { err := sc.CloseSend() From 7bc85107eb6f76e250b874419092c3934ee8deff Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 6 Feb 2026 19:50:48 +0100 Subject: [PATCH 24/27] Adds timing measurement to handleSync to help diagnose sync performance issues (#5228) --- client/internal/engine.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/internal/engine.go b/client/internal/engine.go index 597ac7c2d..4dbd5f45e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -828,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate } func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { + started := time.Now() + defer func() { + log.Infof("sync finished in %s", time.Since(started)) + }() e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() From 391221a986813463345eed059d2593802fdcf823 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:14:02 +0800 Subject: [PATCH 25/27] [client] Fix uspfilter duplicate firewall rules (#5269) --- client/firewall/uspfilter/allow_netbird.go | 34 +- .../uspfilter/allow_netbird_windows.go | 31 +- client/firewall/uspfilter/filter.go | 60 ++- .../uspfilter/filter_routeacl_test.go | 376 ++++++++++++++++++ client/firewall/uspfilter/filter_test.go | 152 +++++++ client/internal/acl/manager_test.go | 206 ++++++++++ 6 files changed, 791 insertions(+), 68 deletions(-) create mode 100644 client/firewall/uspfilter/filter_routeacl_test.go diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 22e6fca1f..6a6533344 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,12 +3,6 @@ package uspfilter import ( - "context" - "net/netip" - "time" - - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - m.outgoingRules = make(map[netip.Addr]RuleSet) - m.incomingDenyRules = make(map[netip.Addr]RuleSet) - m.incomingRules = make(map[netip.Addr]RuleSet) - - if m.udpTracker != nil { - m.udpTracker.Close() - } - - if m.icmpTracker != nil { - m.icmpTracker.Close() - } - - if m.tcpTracker != nil { - m.tcpTracker.Close() - } - - if fwder := m.forwarder.Load(); fwder != nil { - fwder.Stop() - } - - if m.logger != nil { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := m.logger.Stop(ctx); err != nil { - log.Errorf("failed to shutdown logger: %v", err) - } - } + m.resetState() if m.nativeFirewall != nil { return m.nativeFirewall.Close(stateManager) diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 8a56b0862..6aef2ecfd 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -1,12 +1,9 @@ package uspfilter import ( - "context" "fmt" - "net/netip" "os/exec" "syscall" - "time" log "github.com/sirupsen/logrus" @@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - m.outgoingRules = make(map[netip.Addr]RuleSet) - m.incomingDenyRules = make(map[netip.Addr]RuleSet) - m.incomingRules = make(map[netip.Addr]RuleSet) - - if m.udpTracker != nil { - m.udpTracker.Close() - } - - if m.icmpTracker != nil { - m.icmpTracker.Close() - } - - if m.tcpTracker != nil { - m.tcpTracker.Close() - } - - if fwder := m.forwarder.Load(); fwder != nil { - fwder.Stop() - } - - if m.logger != nil { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := m.logger.Stop(ctx); err != nil { - log.Errorf("failed to shutdown logger: %v", err) - } - } + m.resetState() if !isWindowsFirewallReachable() { return nil diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index aacc4ca1c..df2e274eb 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "context" "encoding/binary" "errors" "fmt" @@ -12,11 +13,13 @@ import ( "strings" "sync" "sync/atomic" + "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/common" @@ -24,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface/netstack" + nbid "github.com/netbirdio/netbird/client/internal/acl/id" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -89,6 +93,7 @@ type Manager struct { incomingDenyRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet routeRules RouteRules + routeRulesMap map[nbid.RuleID]*RouteRule decoders sync.Pool wgIface common.IFaceMapper nativeFirewall firewall.Manager @@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe flowLogger: flowLogger, netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, + routeRulesMap: make(map[nbid.RuleID]*RouteRule), dnatMappings: make(map[netip.Addr]netip.Addr), portDNATRules: []portDNATRule{}, netstackServices: make(map[serviceKey]struct{}), @@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering( return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) } - ruleID := uuid.New().String() + ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + + if existingRule, ok := m.routeRulesMap[ruleKey]; ok { + return existingRule, nil + } rule := RouteRule{ // TODO: consolidate these IDs - id: ruleID, + id: string(ruleKey), mgmtId: id, sources: sources, dstSet: destination.Set, @@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering( m.routeRules = append(m.routeRules, &rule) m.routeRules.Sort() + m.routeRulesMap[ruleKey] = &rule return &rule, nil } @@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error { return m.nativeFirewall.DeleteRouteRule(rule) } - ruleID := rule.ID() + ruleKey := nbid.RuleID(rule.ID()) + if _, ok := m.routeRulesMap[ruleKey]; !ok { + return fmt.Errorf("route rule not found: %s", ruleKey) + } + idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool { - return r.id == ruleID + return r.id == string(ruleKey) }) if idx < 0 { - return fmt.Errorf("route rule not found: %s", ruleID) + return fmt.Errorf("route rule not found in slice: %s", ruleKey) } m.routeRules = slices.Delete(m.routeRules, idx, idx+1) + delete(m.routeRulesMap, ruleKey) return nil } @@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } +// resetState clears all firewall rules and closes connection trackers. +// Must be called with m.mutex held. +func (m *Manager) resetState() { + maps.Clear(m.outgoingRules) + maps.Clear(m.incomingDenyRules) + maps.Clear(m.incomingRules) + maps.Clear(m.routeRulesMap) + m.routeRules = m.routeRules[:0] + + if m.udpTracker != nil { + m.udpTracker.Close() + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + } + + if fwder := m.forwarder.Load(); fwder != nil { + fwder.Stop() + } + + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } + } +} + // SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { if m.nativeFirewall == nil { diff --git a/client/firewall/uspfilter/filter_routeacl_test.go b/client/firewall/uspfilter/filter_routeacl_test.go new file mode 100644 index 000000000..68572a01c --- /dev/null +++ b/client/firewall/uspfilter/filter_routeacl_test.go @@ -0,0 +1,376 @@ +package uspfilter + +import ( + "net/netip" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + wgdevice "golang.zx2c4.com/wireguard/device" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/mocks" + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route +// filtering rule twice returns the same rule ID (idempotent behavior). +func TestAddRouteFilteringReturnsExistingRule(t *testing.T) { + manager := setupTestManager(t) + + sources := []netip.Prefix{ + netip.MustParsePrefix("100.64.1.0/24"), + netip.MustParsePrefix("100.64.2.0/24"), + } + destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")} + + // Add rule first time + rule1, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule1) + + // Add the same rule again + rule2, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule2) + + // These should be the same (idempotent) like nftables/iptables implementations + assert.Equal(t, rule1.ID(), rule2.ID(), + "Adding the same rule twice should return the same rule ID (idempotent)") + + manager.mutex.RLock() + ruleCount := len(manager.routeRules) + manager.mutex.RUnlock() + + assert.Equal(t, 2, ruleCount, + "Should have exactly 2 rules (1 user rule + 1 block rule)") +} + +// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with +// different parameters get distinct IDs. +func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) { + manager := setupTestManager(t) + + sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")} + + // Add first rule + rule1, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err) + + // Add different rule (different destination) + rule2, err := manager.AddRouteFiltering( + []byte("policy-2"), + sources, + fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different! + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err) + + assert.NotEqual(t, rule1.ID(), rule2.ID(), + "Different rules should have different IDs") + + manager.mutex.RLock() + ruleCount := len(manager.routeRules) + manager.mutex.RUnlock() + + assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)") +} + +// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route +// rule during a network map update does not disrupt existing traffic. +func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) { + manager := setupTestManager(t) + + sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")} + destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")} + + rule1, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + + srcIP := netip.MustParseAddr("100.64.1.5") + dstIP := netip.MustParseAddr("192.168.1.10") + _, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443) + require.True(t, pass, "Traffic should pass with rule in place") + + // Re-add same rule (simulates network map update) + rule2, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + + // Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager + // won't delete rule1 during cleanup. If IDs differed, deleting rule1 + // would remove the only matching rule and cause a traffic gap. + if rule1.ID() != rule2.ID() { + err = manager.DeleteRouteRule(rule1) + require.NoError(t, err) + } + + _, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443) + assert.True(t, passAfter, + "Traffic should still pass after rule update - no gap should occur") +} + +// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates +// exactly one drop rule for the WireGuard network prefix, and calling it again +// returns the same rule without duplicating. +func TestBlockInvalidRoutedIdempotent(t *testing.T) { + ctrl := gomock.NewController(t) + dev := mocks.NewMockDevice(ctrl) + dev.EXPECT().MTU().Return(1500, nil).AnyTimes() + + wgNet := netip.MustParsePrefix("100.64.0.1/16") + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: wgNet.Addr(), + Network: wgNet, + } + }, + GetDeviceFunc: func() *device.FilteredDevice { + return &device.FilteredDevice{Device: dev} + }, + GetWGDeviceFunc: func() *wgdevice.Device { + return &wgdevice.Device{} + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + // Call blockInvalidRouted directly multiple times + rule1, err := manager.blockInvalidRouted(ifaceMock) + require.NoError(t, err) + require.NotNil(t, rule1) + + rule2, err := manager.blockInvalidRouted(ifaceMock) + require.NoError(t, err) + require.NotNil(t, rule2) + + rule3, err := manager.blockInvalidRouted(ifaceMock) + require.NoError(t, err) + require.NotNil(t, rule3) + + // All should return the same rule + assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule") + assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule") + + // Should have exactly 1 route rule + manager.mutex.RLock() + ruleCount := len(manager.routeRules) + manager.mutex.RUnlock() + + assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls") + + // Verify the rule blocks traffic to the WG network + srcIP := netip.MustParseAddr("10.0.0.1") + dstIP := netip.MustParseAddr("100.64.0.50") + _, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80) + assert.False(t, pass, "Block rule should deny traffic to WG prefix") +} + +// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling +// EnableRouting multiple times (as happens on each route update) does not +// accumulate duplicate block rules in the routeRules slice. +func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) { + ctrl := gomock.NewController(t) + dev := mocks.NewMockDevice(ctrl) + dev.EXPECT().MTU().Return(1500, nil).AnyTimes() + + wgNet := netip.MustParsePrefix("100.64.0.1/16") + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: wgNet.Addr(), + Network: wgNet, + } + }, + GetDeviceFunc: func() *device.FilteredDevice { + return &device.FilteredDevice{Device: dev} + }, + GetWGDeviceFunc: func() *wgdevice.Device { + return &wgdevice.Device{} + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + // Call EnableRouting multiple times (simulating repeated route updates) + for i := 0; i < 5; i++ { + require.NoError(t, manager.EnableRouting()) + } + + manager.mutex.RLock() + ruleCount := len(manager.routeRules) + manager.mutex.RUnlock() + + assert.Equal(t, 1, ruleCount, + "Repeated EnableRouting should not accumulate block rules") +} + +// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route +// rule multiple times does not create duplicate entries. +func TestRouteRuleCountStableAcrossUpdates(t *testing.T) { + manager := setupTestManager(t) + + sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")} + destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")} + + // Simulate 5 network map updates with the same route rule + for i := 0; i < 5; i++ { + rule, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []uint16{443}}, + fw.ActionAccept, + ) + require.NoError(t, err) + require.NotNil(t, rule) + } + + manager.mutex.RLock() + ruleCount := len(manager.routeRules) + manager.mutex.RUnlock() + + assert.Equal(t, 2, ruleCount, + "Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates") +} + +// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule +// after adding it multiple times works correctly. +func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) { + manager := setupTestManager(t) + + sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")} + destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")} + + // Add same rule twice + rule1, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + + rule2, err := manager.AddRouteFiltering( + []byte("policy-1"), + sources, + destination, + fw.ProtocolTCP, + nil, + nil, + fw.ActionAccept, + ) + require.NoError(t, err) + + require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID") + + // Delete using first reference + err = manager.DeleteRouteRule(rule1) + require.NoError(t, err) + + // Verify traffic no longer passes + srcIP := netip.MustParseAddr("100.64.1.5") + dstIP := netip.MustParseAddr("192.168.1.10") + _, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443) + assert.False(t, pass, "Traffic should not pass after rule deletion") +} + +func setupTestManager(t *testing.T) *Manager { + t.Helper() + + ctrl := gomock.NewController(t) + dev := mocks.NewMockDevice(ctrl) + dev.EXPECT().MTU().Return(1500, nil).AnyTimes() + + wgNet := netip.MustParsePrefix("100.64.0.1/16") + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: wgNet.Addr(), + Network: wgNet, + } + }, + GetDeviceFunc: func() *device.FilteredDevice { + return &device.FilteredDevice{Device: dev} + }, + GetWGDeviceFunc: func() *wgdevice.Device { + return &wgdevice.Device{} + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + require.NoError(t, manager.EnableRouting()) + + t.Cleanup(func() { + require.NoError(t, manager.Close(nil)) + }) + + return manager +} diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index c6a4ebeb8..55a8e723c 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) { } } +// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added +// to the deny map and can be cleanly deleted without leaving orphans. +func TestPeerRuleLifecycleDenyRules(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, m.Close(nil)) + }() + + ip := net.ParseIP("192.168.1.1") + addr := netip.MustParseAddr("192.168.1.1") + + // Add multiple deny rules for different ports + rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, + &fw.Port{Values: []uint16{22}}, fw.ActionDrop, "") + require.NoError(t, err) + + rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, + &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "") + require.NoError(t, err) + + m.mutex.RLock() + denyCount := len(m.incomingDenyRules[addr]) + m.mutex.RUnlock() + require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules") + + // Delete the first deny rule + err = m.DeletePeerRule(rule1[0]) + require.NoError(t, err) + + m.mutex.RLock() + denyCount = len(m.incomingDenyRules[addr]) + m.mutex.RUnlock() + require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first") + + // Delete the second deny rule + err = m.DeletePeerRule(rule2[0]) + require.NoError(t, err) + + m.mutex.RLock() + _, exists := m.incomingDenyRules[addr] + m.mutex.RUnlock() + require.False(t, exists, "Deny rules IP entry should be cleaned up when empty") +} + +// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting +// peer rules (simulating network map updates) does not leak rules in the maps. +func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, m.Close(nil)) + }() + + ip := net.ParseIP("192.168.1.1") + addr := netip.MustParseAddr("192.168.1.1") + + // Simulate 10 network map updates: add rule, delete old, add new + for i := 0; i < 10; i++ { + // Add a deny rule + rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, + &fw.Port{Values: []uint16{22}}, fw.ActionDrop, "") + require.NoError(t, err) + + // Add an allow rule + allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, + &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err) + + // Delete them (simulating ACL manager cleanup) + for _, r := range rules { + require.NoError(t, m.DeletePeerRule(r)) + } + for _, r := range allowRules { + require.NoError(t, m.DeletePeerRule(r)) + } + } + + m.mutex.RLock() + denyCount := len(m.incomingDenyRules[addr]) + allowCount := len(m.incomingRules[addr]) + m.mutex.RUnlock() + + require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup") + require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup") +} + +// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same +// IP are stored in separate maps and don't interfere with each other. +func TestMixedAllowDenyRulesSameIP(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + } + + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, m.Close(nil)) + }() + + ip := net.ParseIP("192.168.1.1") + + // Add allow rule for port 80 + allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, + &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") + require.NoError(t, err) + + // Add deny rule for port 22 + denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, + &fw.Port{Values: []uint16{22}}, fw.ActionDrop, "") + require.NoError(t, err) + + addr := netip.MustParseAddr("192.168.1.1") + m.mutex.RLock() + allowCount := len(m.incomingRules[addr]) + denyCount := len(m.incomingDenyRules[addr]) + m.mutex.RUnlock() + + require.Equal(t, 1, allowCount, "Should have 1 allow rule") + require.Equal(t, 1, denyCount, "Should have 1 deny rule") + + // Delete allow rule should not affect deny rule + err = m.DeletePeerRule(allowRule[0]) + require.NoError(t, err) + + m.mutex.RLock() + denyCountAfter := len(m.incomingDenyRules[addr]) + m.mutex.RUnlock() + + require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule") + + // Delete deny rule + err = m.DeletePeerRule(denyRule[0]) + require.NoError(t, err) + + m.mutex.RLock() + _, denyExists := m.incomingDenyRules[addr] + _, allowExists := m.incomingRules[addr] + m.mutex.RUnlock() + + require.False(t, denyExists, "Deny rules should be empty") + require.False(t, allowExists, "Allow rules should be empty") +} + func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 4bc0fd800..bd7adfaef 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) { }) } +// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same +// deny rules repeatedly does not accumulate duplicate rules in the uspfilter. +// This tests the full ACL manager -> uspfilter integration. +func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) { + t.Setenv("NB_WG_KERNEL_DISABLED", "true") + + networkMap := &mgmProto.NetworkMap{ + FirewallRules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "22", + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "80", + }, + { + PeerIP: "10.93.0.3", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + }, + FirewallRulesIsEmpty: false, + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() + ifaceMock.EXPECT().SetFilter(gomock.Any()) + network := netip.MustParsePrefix("172.0.0.1/32") + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(wgaddr.Address{ + IP: network.Addr(), + Network: network, + }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() + + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, fw.Close(nil)) + }() + + acl := NewDefaultManager(fw) + + // Apply the same rules 5 times (simulating repeated network map updates) + for i := 0; i < 5; i++ { + acl.ApplyFiltering(networkMap, false) + } + + // The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound) + assert.Equal(t, 3, len(acl.peerRulesPairs), + "Should have exactly 3 rule pairs after 5 identical updates") +} + +// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned +// up when they're removed from the network map in a subsequent update. +func TestDenyRulesCleanedUpOnRemoval(t *testing.T) { + t.Setenv("NB_WG_KERNEL_DISABLED", "true") + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() + ifaceMock.EXPECT().SetFilter(gomock.Any()) + network := netip.MustParsePrefix("172.0.0.1/32") + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(wgaddr.Address{ + IP: network.Addr(), + Network: network, + }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() + + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, fw.Close(nil)) + }() + + acl := NewDefaultManager(fw) + + // First update: add deny and accept rules + networkMap1 := &mgmProto.NetworkMap{ + FirewallRules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "22", + }, + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + }, + FirewallRulesIsEmpty: false, + } + + acl.ApplyFiltering(networkMap1, false) + assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update") + + // Second update: remove the deny rule, keep only accept + networkMap2 := &mgmProto.NetworkMap{ + FirewallRules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.2", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "443", + }, + }, + FirewallRulesIsEmpty: false, + } + + acl.ApplyFiltering(networkMap2, false) + assert.Equal(t, 1, len(acl.peerRulesPairs), + "Should have 1 rule after removing deny rule") + + // Third update: remove all rules + networkMap3 := &mgmProto.NetworkMap{ + FirewallRules: []*mgmProto.FirewallRule{}, + FirewallRulesIsEmpty: true, + } + + acl.ApplyFiltering(networkMap3, false) + assert.Equal(t, 0, len(acl.peerRulesPairs), + "Should have 0 rules after removing all rules") +} + +// TestRuleUpdateChangingAction verifies that when a rule's action changes from +// accept to deny (or vice versa), the old rule is properly removed and the new +// one added without leaking. +func TestRuleUpdateChangingAction(t *testing.T) { + t.Setenv("NB_WG_KERNEL_DISABLED", "true") + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ifaceMock := mocks.NewMockIFaceMapper(ctrl) + ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() + ifaceMock.EXPECT().SetFilter(gomock.Any()) + network := netip.MustParsePrefix("172.0.0.1/32") + ifaceMock.EXPECT().Name().Return("lo").AnyTimes() + ifaceMock.EXPECT().Address().Return(wgaddr.Address{ + IP: network.Addr(), + Network: network, + }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() + + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, fw.Close(nil)) + }() + + acl := NewDefaultManager(fw) + + // First update: accept rule + networkMap := &mgmProto.NetworkMap{ + FirewallRules: []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "22", + }, + }, + FirewallRulesIsEmpty: false, + } + acl.ApplyFiltering(networkMap, false) + assert.Equal(t, 1, len(acl.peerRulesPairs)) + + // Second update: change to deny (same IP/port/proto, different action) + networkMap.FirewallRules = []*mgmProto.FirewallRule{ + { + PeerIP: "10.93.0.1", + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_TCP, + Port: "22", + }, + } + acl.ApplyFiltering(networkMap, false) + + // Should still have exactly 1 rule (the old accept removed, new deny added) + assert.Equal(t, 1, len(acl.peerRulesPairs), + "Changing action should result in exactly 1 rule, not 2") +} + func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string From 08403f64aa07c230ea288ed8764596eff3a33d26 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 9 Feb 2026 18:09:11 +0800 Subject: [PATCH 26/27] [client] Add env var to skip DNS probing (#5270) --- client/internal/dns/server.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 4d4fcc06e..c2b01de62 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -6,7 +6,9 @@ import ( "fmt" "net/netip" "net/url" + "os" "runtime" + "strconv" "strings" "sync" @@ -27,6 +29,8 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" ) +const envSkipDNSProbe = "NB_SKIP_DNS_PROBE" + // ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes type ReadyListener interface { OnReady() @@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string { // ProbeAvailability tests each upstream group's servers for availability // and deactivates the group if no server responds func (s *DefaultServer) ProbeAvailability() { + if val := os.Getenv(envSkipDNSProbe); val != "" { + skipProbe, err := strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err) + } + if skipProbe { + log.Infof("skipping DNS probe due to %s", envSkipDNSProbe) + return + } + } + var wg sync.WaitGroup for _, mux := range s.dnsMuxMap { wg.Add(1) From 6981fdce7e86640b9ff1c222a807390296fc684c Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 9 Feb 2026 11:34:24 +0100 Subject: [PATCH 27/27] [client] Fix race condition and ensure correct message ordering in Relay (#5265) * Fix race condition and ensure correct message ordering in connection establishment Reorder operations in OpenConn to register the connection before waiting for peer availability. This ensures: - Connection is ready to receive messages before peer subscription completes - Transport messages and onconnected events maintain proper ordering - No messages are lost during the connection establishment window - Concurrent OpenConn calls cannot create duplicate connections If peer availability check fails, the pre-registered connection is properly cleaned up. * Handle service shutdown during relay connection initialization Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors. --- shared/relay/client/client.go | 51 ++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 57a98614d..0acadaa4b 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -225,35 +225,42 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro c.mu.Unlock() return nil, ErrConnAlreadyExists } - c.mu.Unlock() - if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil { - c.log.Errorf("peer not available: %s, %s", peerID, err) - return nil, err - } - - c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID) - msgChannel := make(chan Msg, 100) - - c.mu.Lock() - if !c.serviceIsRunning { - c.mu.Unlock() - return nil, fmt.Errorf("relay connection is not established") - } + c.log.Infof("prepare the relayed connection, waiting for remote peer: %s", peerID) c.muInstanceURL.Lock() instanceURL := c.instanceURL c.muInstanceURL.Unlock() - conn := NewConn(c, peerID, msgChannel, instanceURL) - _, ok = c.conns[peerID] - if ok { - c.mu.Unlock() - _ = conn.Close() - return nil, ErrConnAlreadyExists - } - c.conns[peerID] = newConnContainer(c.log, conn, msgChannel) + msgChannel := make(chan Msg, 100) + conn := NewConn(c, peerID, msgChannel, instanceURL) + container := newConnContainer(c.log, conn, msgChannel) + c.conns[peerID] = container c.mu.Unlock() + + if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil { + c.log.Errorf("peer not available: %s, %s", peerID, err) + c.mu.Lock() + if savedContainer, ok := c.conns[peerID]; ok && savedContainer == container { + delete(c.conns, peerID) + } + c.mu.Unlock() + container.close() + return nil, err + } + + c.mu.Lock() + if !c.serviceIsRunning { + if savedContainer, ok := c.conns[peerID]; ok && savedContainer == container { + delete(c.conns, peerID) + } + c.mu.Unlock() + container.close() + return nil, fmt.Errorf("relay connection is not established") + } + c.mu.Unlock() + + c.log.Infof("remote peer is available: %s", peerID) return conn, nil }