From 34341d95a930a252588933eada70f92c296d756f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 6 Oct 2025 15:22:02 -0300 Subject: [PATCH 01/29] Adjust signal port for websocket connections (#4594) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 2 +- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index fb01e6867..196e26a66 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -47,7 +47,7 @@ services: - traefik.enable=true - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 + - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index be9662345..bc326cd7e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -621,7 +621,7 @@ renderCaddyfile() { # relay reverse_proxy /relay* relay:80 # Signal - reverse_proxy /ws-proxy/signal* signal:10000 + reverse_proxy /ws-proxy/signal* signal:80 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management reverse_proxy /api/* management:80 From 954f40991f3b6c399e6ce1cb20577aacca87d38e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Oct 2025 21:22:19 +0200 Subject: [PATCH 02/29] [client,management,signal] Handle grpc from ws proxy internally instead of via tcp (#4593) --- client/grpc/dialer.go | 7 +- management/internals/server/server.go | 3 +- signal/cmd/run.go | 9 +- util/wsproxy/server/proxy.go | 140 ++++++++++---------------- 4 files changed, 62 insertions(+), 97 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 54fbb002c..6aff53b92 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff { // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - if tlsEnabled { + // for js, the outer websocket layer takes care of tls + if tlsEnabled && runtime.GOOS != "js" { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -37,9 +38,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - // for js, outer websocket layer takes care of tls verification via WithCustomDialer - InsecureSkipVerify: runtime.GOOS == "js", - RootCAs: certPool, + RootCAs: certPool, })) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 94c633fc6..ab1c2ebe7 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "net/netip" "strings" "sync" "time" @@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) } func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { switch { diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 696c44723..96873dee7 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -10,7 +10,6 @@ import ( "net/http" // nolint:gosec _ "net/http/pprof" - "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -63,10 +62,10 @@ var ( Use: "run", Short: "start NetBird Signal Server daemon", SilenceUsage: true, - PreRun: func(cmd *cobra.Command, args []string) { + PreRunE: func(cmd *cobra.Command, args []string) error { err := util.InitLog(logLevel, logFile) if err != nil { - log.Fatalf("failed initializing log %v", err) + return fmt.Errorf("failed initializing log: %w", err) } flag.Parse() @@ -87,6 +86,8 @@ var ( signalPort = 80 } } + + return nil }, RunE: func(cmd *cobra.Command, args []string) error { flag.Parse() @@ -254,7 +255,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { - wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go index 977440a60..8866924df 100644 --- a/util/wsproxy/server/proxy.go +++ b/util/wsproxy/server/proxy.go @@ -2,42 +2,41 @@ package server import ( "context" - "errors" "io" "net" "net/http" - "net/netip" "sync" "time" "github.com/coder/websocket" log "github.com/sirupsen/logrus" + "golang.org/x/net/http2" "github.com/netbirdio/netbird/util/wsproxy" ) const ( - dialTimeout = 10 * time.Second - bufferSize = 32 * 1024 + bufferSize = 32 * 1024 + ioTimeout = 5 * time.Second ) // Config contains the configuration for the WebSocket proxy. type Config struct { - LocalGRPCAddr netip.AddrPort + Handler http.Handler Path string MetricsRecorder MetricsRecorder } -// Proxy handles WebSocket to TCP proxying for gRPC connections. +// Proxy handles WebSocket to gRPC handler proxying. type Proxy struct { config Config metrics MetricsRecorder } // New creates a new WebSocket proxy instance with optional configuration -func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { +func New(handler http.Handler, opts ...Option) *Proxy { config := Config{ - LocalGRPCAddr: localGRPCAddr, + Handler: handler, Path: wsproxy.ProxyPath, MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op } @@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { p.metrics.RecordConnection(ctx) defer p.metrics.RecordDisconnection(ctx) - log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr) acceptOptions := &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, } @@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { return } defer func() { - if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { - log.Debugf("Failed to close WebSocket: %v", err) - } + _ = wsConn.Close(websocket.StatusNormalClosure, "") }() - log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) - tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) - if err != nil { - p.metrics.RecordError(ctx, "tcp_dial_failed") - log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) - if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { - log.Debugf("Failed to close WebSocket after connection failure: %v", err) - } - return - } + clientConn, serverConn := net.Pipe() defer func() { - if err := tcpConn.Close(); err != nil { - log.Debugf("Failed to close TCP connection: %v", err) - } + _ = clientConn.Close() + _ = serverConn.Close() }() - log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr) - p.proxyData(ctx, wsConn, tcpConn) + go func() { + (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{ + Context: ctx, + Handler: p.config.Handler, + }) + }() + + p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr) } -func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { proxyCtx, cancel := context.WithCancel(ctx) defer cancel() var wg sync.WaitGroup wg.Add(2) - go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) - go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) + go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - - select { - case <-done: - log.Tracef("Proxy data transfer completed, both goroutines terminated") - case <-proxyCtx.Done(): - log.Tracef("Proxy data transfer cancelled, forcing connection closure") - - if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { - log.Tracef("Error closing WebSocket during cancellation: %v", err) - } - if err := tcpConn.Close(); err != nil { - log.Tracef("Error closing TCP connection during cancellation: %v", err) - } - - select { - case <-done: - log.Tracef("Goroutines terminated after forced connection closure") - case <-time.After(2 * time.Second): - log.Tracef("Goroutines did not terminate within timeout after connection closure") - } - } + wg.Wait() } -func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { defer wg.Done() defer cancel() @@ -148,80 +117,77 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync if err != nil { switch { case ctx.Err() != nil: - log.Debugf("wsToTCP goroutine terminating due to context cancellation") - case websocket.CloseStatus(err) == websocket.StatusNormalClosure: - log.Debugf("WebSocket closed normally") + log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr) + case websocket.CloseStatus(err) != -1: + log.Debugf("WebSocket from %s disconnected", clientAddr) default: p.metrics.RecordError(ctx, "websocket_read_error") - log.Errorf("WebSocket read error: %v", err) + log.Debugf("WebSocket read error from %s: %v", clientAddr, err) } return } if msgType != websocket.MessageBinary { - log.Warnf("Unexpected WebSocket message type: %v", msgType) + log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType) continue } if ctx.Err() != nil { - log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write") return } - if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Debugf("Failed to set TCP write deadline: %v", err) + if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil { + log.Debugf("Failed to set pipe write deadline: %v", err) } - n, err := tcpConn.Write(data) + n, err := pipeConn.Write(data) if err != nil { - p.metrics.RecordError(ctx, "tcp_write_error") - log.Errorf("TCP write error: %v", err) + p.metrics.RecordError(ctx, "pipe_write_error") + log.Warnf("Pipe write error for %s: %v", clientAddr, err) return } - p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n)) } } -func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { +func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { defer wg.Done() defer cancel() buf := make([]byte, bufferSize) for { - if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { - log.Debugf("Failed to set TCP read deadline: %v", err) + if err := pipeConn.SetReadDeadline(time.Now().Add(ioTimeout)); err != nil { + log.Debugf("Failed to set pipe read deadline: %v", err) } - n, err := tcpConn.Read(buf) + n, err := pipeConn.Read(buf) if err != nil { if ctx.Err() != nil { - log.Tracef("tcpToWS goroutine terminating due to context cancellation") + log.Tracef("pipeToWS goroutine terminating due to context cancellation") return } - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - continue - } - if err != io.EOF { - log.Errorf("TCP read error: %v", err) + log.Debugf("Pipe read error for %s: %v", clientAddr, err) } return } if ctx.Err() != nil { - log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write") return } - if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { - p.metrics.RecordError(ctx, "websocket_write_error") - log.Errorf("WebSocket write error: %v", err) - return - } + if n > 0 { + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Warnf("WebSocket write error for %s: %v", clientAddr, err) + return + } - p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n)) + } } } From 88467883fc4a14607393beaac412f573eaaa1d43 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Oct 2025 22:05:48 +0200 Subject: [PATCH 03/29] [management,signal] Remove ws-proxy read deadline (#4598) --- util/wsproxy/server/proxy.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go index 8866924df..ffb622200 100644 --- a/util/wsproxy/server/proxy.go +++ b/util/wsproxy/server/proxy.go @@ -158,10 +158,6 @@ func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *syn buf := make([]byte, bufferSize) for { - if err := pipeConn.SetReadDeadline(time.Now().Add(ioTimeout)); err != nil { - log.Debugf("Failed to set pipe read deadline: %v", err) - } - n, err := pipeConn.Read(buf) if err != nil { if ctx.Err() != nil { From 4d33567888fe704fb75f0c0a74ed8a27fd131052 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 8 Oct 2025 03:12:16 +0200 Subject: [PATCH 04/29] [client] Remove endpoint address on peer disconnect, retain status for activity recording (#4228) * When a peer disconnects, remove the endpoint address to avoid sending traffic to a non-existent address, but retain the status for the activity recorder. --- client/iface/configurer/kernel_unix.go | 38 ++++++++++++++++ client/iface/configurer/usp.go | 61 ++++++++++++++++++++++++++ client/iface/device/interface.go | 1 + client/iface/iface.go | 11 +++++ client/internal/engine_test.go | 4 ++ client/internal/iface_common.go | 1 + client/internal/peer/conn.go | 6 +++ client/internal/peer/iface.go | 1 + 8 files changed, 123 insertions(+) diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 84afc38f5..96b286175 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return err + } + + // Get the existing peer to preserve its allowed IPs + existingPeer, err := c.getPeer(c.deviceName, peerKey) + if err != nil { + return fmt.Errorf("get peer: %w", err) + } + + removePeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil { + return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err) + } + + //Re-add the peer without the endpoint but same AllowedIPs + reAddPeerCfg := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + AllowedIPs: existingPeer.AllowedIPs, + ReplaceAllowedIPs: true, + } + + if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil { + return fmt.Errorf( + `error re-adding peer %s to interface %s with allowed IPs %v: %w`, + peerKey, c.deviceName, existingPeer.AllowedIPs, err, + ) + } + + return nil +} + func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index f744e0127..bc875b73c 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, return nil } +func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error { + peerKeyParsed, err := wgtypes.ParseKey(peerKey) + if err != nil { + return fmt.Errorf("parse peer key: %w", err) + } + + ipcStr, err := c.device.IpcGet() + if err != nil { + return fmt.Errorf("get IPC config: %w", err) + } + + // Parse current status to get allowed IPs for the peer + stats, err := parseStatus(c.deviceName, ipcStr) + if err != nil { + return fmt.Errorf("parse IPC config: %w", err) + } + + var allowedIPs []net.IPNet + found := false + for _, peer := range stats.Peers { + if peer.PublicKey == peerKey { + allowedIPs = peer.AllowedIPs + found = true + break + } + } + if !found { + return fmt.Errorf("peer %s not found", peerKey) + } + + // remove the peer from the WireGuard configuration + peer := wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + Remove: true, + } + + config := wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil { + return fmt.Errorf("failed to remove peer: %s", ipcErr) + } + + // Build the peer config + peer = wgtypes.PeerConfig{ + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + } + + config = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{peer}, + } + + if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil { + return fmt.Errorf("remove endpoint address: %w", err) + } + + return nil +} + func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 1f40b0d46..db53d9c3a 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -21,4 +21,5 @@ type WGConfigurer interface { GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) LastActivities() map[string]monotime.Time + RemoveEndpointAddress(peerKey string) error } diff --git a/client/iface/iface.go b/client/iface/iface.go index 609572561..158672160 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } +func (w *WGIface) RemoveEndpointAddress(peerKey string) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } + + log.Debugf("Removing endpoint address: %s", peerKey) + return w.configurer.RemoveEndpointAddress(peerKey) +} + // RemovePeer removes a Wireguard Peer from the interface iface func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 344104405..2f1098100 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -105,6 +105,10 @@ type MockWGIface struct { LastActivitiesFunc func() map[string]monotime.Time } +func (m *MockWGIface) RemoveEndpointAddress(_ string) error { + return nil +} + func (m *MockWGIface) FullStats() (*configurer.Stats, error) { return nil, fmt.Errorf("not implemented") } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 690fdb7cc..98fe01912 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -28,6 +28,7 @@ type wgIfaceBase interface { UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemoveEndpointAddress(key string) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8db9e58f4..ded9aa479 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() { } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } changed := conn.statusICE.Get() != worker.StatusDisconnected @@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() { if conn.currentConnPriority == conntype.Relay { conn.Log.Debugf("clean up WireGuard config") conn.currentConnPriority = conntype.None + if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil { + conn.Log.Errorf("failed to remove wg endpoint: %v", err) + } } if conn.wgProxyRelay != nil { diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go index 0bcc7a68e..678396e61 100644 --- a/client/internal/peer/iface.go +++ b/client/internal/peer/iface.go @@ -18,4 +18,5 @@ type WGIface interface { GetStats() (map[string]configurer.WGStats, error) GetProxy() wgproxy.Proxy Address() wgaddr.Address + RemoveEndpointAddress(key string) error } From 229c65ffa1d783c0f1b56070c0718897f69072b0 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Wed, 8 Oct 2025 17:42:15 +0700 Subject: [PATCH 05/29] Enhance showLoginURL to include connection status check and auto-close functionality (#4525) --- client/ui/client_ui.go | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 25d7380a9..7c2000a9d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -1354,7 +1354,13 @@ func (s *serviceClient) updateConfig() error { } // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. -func (s *serviceClient) showLoginURL() { +// It also starts a background goroutine that periodically checks if the client is already connected +// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is +// also cancelled when the window is closed. +func (s *serviceClient) showLoginURL() context.CancelFunc { + + // create a cancellable context for the background check goroutine + ctx, cancel := context.WithCancel(s.ctx) resIcon := fyne.NewStaticResource("netbird.png", iconAbout) @@ -1363,6 +1369,8 @@ func (s *serviceClient) showLoginURL() { s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.SetIcon(resIcon) } + // ensure goroutine is cancelled when the window is closed + s.wLoginURL.SetOnClosed(func() { cancel() }) // add a description label label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") @@ -1443,7 +1451,39 @@ func (s *serviceClient) showLoginURL() { ) s.wLoginURL.SetContent(container.NewCenter(content)) + // start a goroutine to check connection status and close the window if connected + go func() { + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + conn, err := s.getSrvClient(failFastTimeout) + if err != nil { + return + } + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + if err != nil { + continue + } + if status.Status == string(internal.StatusConnected) { + if s.wLoginURL != nil { + s.wLoginURL.Close() + } + return + } + } + } + }() + s.wLoginURL.Show() + + // return cancel func so callers can stop the background goroutine if desired + return cancel } func openURL(url string) error { From 768332820e1e6e211325332a1350b0b2a34429b3 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Wed, 8 Oct 2025 21:54:27 +0700 Subject: [PATCH 06/29] [client] Implement DNS query caching in DNSForwarder (#4574) implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails. - Added a new cache component to store DNS query results by domain and query type - Integrated cache storage after successful DNS resolutions - Enhanced error handling to serve cached responses as fallback when upstream DNS fails --- client/internal/dnsfwd/cache.go | 78 +++++++++++++++++ client/internal/dnsfwd/cache_test.go | 86 +++++++++++++++++++ client/internal/dnsfwd/forwarder.go | 104 +++++++++++++++++++---- client/internal/dnsfwd/forwarder_test.go | 89 +++++++++++++++++++ 4 files changed, 341 insertions(+), 16 deletions(-) create mode 100644 client/internal/dnsfwd/cache.go create mode 100644 client/internal/dnsfwd/cache_test.go diff --git a/client/internal/dnsfwd/cache.go b/client/internal/dnsfwd/cache.go new file mode 100644 index 000000000..43fe2d020 --- /dev/null +++ b/client/internal/dnsfwd/cache.go @@ -0,0 +1,78 @@ +package dnsfwd + +import ( + "net/netip" + "slices" + "strings" + "sync" + + "github.com/miekg/dns" +) + +type cache struct { + mu sync.RWMutex + records map[string]*cacheEntry +} + +type cacheEntry struct { + ip4Addrs []netip.Addr + ip6Addrs []netip.Addr +} + +func newCache() *cache { + return &cache{ + records: make(map[string]*cacheEntry), + } +} + +func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.records[normalizeDomain(domain)] + if !exists { + return nil, false + } + + switch reqType { + case dns.TypeA: + return slices.Clone(entry.ip4Addrs), true + case dns.TypeAAAA: + return slices.Clone(entry.ip6Addrs), true + default: + return nil, false + } + +} + +func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) { + c.mu.Lock() + defer c.mu.Unlock() + norm := normalizeDomain(domain) + entry, exists := c.records[norm] + if !exists { + entry = &cacheEntry{} + c.records[norm] = entry + } + + switch reqType { + case dns.TypeA: + entry.ip4Addrs = slices.Clone(addrs) + case dns.TypeAAAA: + entry.ip6Addrs = slices.Clone(addrs) + } +} + +// unset removes cached entries for the given domain and request type. +func (c *cache) unset(domain string) { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.records, normalizeDomain(domain)) +} + +// normalizeDomain converts an input domain into a canonical form used as cache key: +// lowercase and fully-qualified (with trailing dot). +func normalizeDomain(domain string) string { + // dns.Fqdn ensures trailing dot; ToLower for consistent casing + return dns.Fqdn(strings.ToLower(domain)) +} diff --git a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go new file mode 100644 index 000000000..c23f0f31d --- /dev/null +++ b/client/internal/dnsfwd/cache_test.go @@ -0,0 +1,86 @@ +package dnsfwd + +import ( + "net/netip" + "testing" +) + +func mustAddr(t *testing.T, s string) netip.Addr { + t.Helper() + a, err := netip.ParseAddr(s) + if err != nil { + t.Fatalf("parse addr %s: %v", s, err) + } + return a +} + +func TestCacheNormalization(t *testing.T) { + c := newCache() + + // Mixed case, without trailing dot + domainInput := "ExAmPlE.CoM" + ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")} + c.set(domainInput, 1 /* dns.TypeA */, ipv4) + + // Lookup with lower, with trailing dot + if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok) + } + + // Lookup with different casing again + if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" { + t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok) + } +} + +func TestCacheSeparateTypes(t *testing.T) { + c := newCache() + + domain := "test.local" + ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")} + ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")} + + c.set(domain, 1 /* A */, ipv4) + c.set(domain, 28 /* AAAA */, ipv6) + + got4, ok4 := c.get(domain, 1) + if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] { + t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4) + } + + got6, ok6 := c.get(domain, 28) + if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] { + t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6) + } +} + +func TestCacheCloneOnGetAndSet(t *testing.T) { + c := newCache() + domain := "clone.test" + + src := []netip.Addr{mustAddr(t, "8.8.8.8")} + c.set(domain, 1, src) + + // Mutate source slice; cache should be unaffected + src[0] = mustAddr(t, "9.9.9.9") + + got, ok := c.get(domain, 1) + if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" { + t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok) + } + + // Mutate returned slice; internal cache should remain unchanged + got[0] = mustAddr(t, "4.4.4.4") + got2, ok2 := c.get(domain, 1) + if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" { + t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2) + } +} + +func TestCacheMiss(t *testing.T) { + c := newCache() + if got, ok := c.get("missing.example", 1); ok || got != nil { + t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) + } +} + diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index d912919a1..7a262fa4c 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -46,6 +46,7 @@ type DNSForwarder struct { fwdEntries []*ForwarderEntry firewall firewaller resolver resolver + cache *cache } func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { @@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat firewall: firewall, statusRecorder: statusRecorder, resolver: net.DefaultResolver, + cache: newCache(), } } @@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() + // remove cache entries for domains that no longer appear + f.removeStaleCacheEntries(f.fwdEntries, entries) + f.fwdEntries = entries log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } +// removeStaleCacheEntries unsets cache items for domains that were present +// in the old list but not present in the new list. +func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) { + if f.cache == nil { + return + } + + newSet := make(map[string]struct{}, len(newEntries)) + for _, e := range newEntries { + if e == nil { + continue + } + newSet[e.Domain.PunycodeString()] = struct{}{} + } + + for _, e := range oldEntries { + if e == nil { + continue + } + pattern := e.Domain.PunycodeString() + if _, ok := newSet[pattern]; !ok { + f.cache.unset(pattern) + } + } +} + func (f *DNSForwarder) Close(ctx context.Context) error { var result *multierror.Error @@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) + f.cache.set(domain, question.Qtype, ips) return resp } @@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns resp.Rcode = dns.RcodeSuccess } -// handleDNSError processes DNS lookup errors and sends an appropriate error response -func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) { +// handleDNSError processes DNS lookup errors and sends an appropriate error response. +func (f *DNSForwarder) handleDNSError( + ctx context.Context, + w dns.ResponseWriter, + question dns.Question, + resp *dns.Msg, + domain string, + err error, +) { + // Default to SERVFAIL; override below when appropriate. + resp.Rcode = dns.RcodeServerFailure + + qType := question.Qtype + qTypeName := dns.TypeToString[qType] + + // Prefer typed DNS errors; fall back to generic logging otherwise. var dnsErr *net.DNSError - - switch { - case errors.As(err, &dnsErr): - resp.Rcode = dns.RcodeServerFailure - if dnsErr.IsNotFound { - f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype) + if !errors.As(err, &dnsErr) { + log.Warnf(errResolveFailed, domain, err) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } + return + } - if dnsErr.Server != "" { - log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err) - } else { - log.Warnf(errResolveFailed, domain, err) + // NotFound: set NXDOMAIN / appropriate code via helper. + if dnsErr.IsNotFound { + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } - default: - resp.Rcode = dns.RcodeServerFailure + f.cache.set(domain, question.Qtype, nil) + return + } + + // Upstream failed but we might have a cached answer—serve it if present. + if ips, ok := f.cache.get(domain, qType); ok { + if len(ips) > 0 { + log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) + f.addIPsToResponse(resp, domain, ips) + resp.Rcode = dns.RcodeSuccess + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write cached DNS response: %v", writeErr) + } + } else { // send NXDOMAIN / appropriate code if cache is empty + f.setResponseCodeForNotFound(ctx, resp, domain, qType) + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) + } + } + return + } + + // No cache. Log with or without the server field for more context. + if dnsErr.Server != "" { + log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err) + } else { log.Warnf(errResolveFailed, domain, err) } - if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write failure DNS response: %v", err) + // Write final failure response. + if writeErr := w.WriteMsg(resp); writeErr != nil { + log.Errorf("failed to write failure DNS response: %v", writeErr) } } diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 57085e19a..c1c95a2c1 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") } +// Ensures that when the first query succeeds and populates the cache, +// a subsequent upstream failure still returns a successful response from cache. +func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("example.com") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("1.2.3.4") + + // First call resolves successfully and populates cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{ip}, nil).Once() + + // Second call fails upstream; forwarder should serve from cache + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + // First query: populate cache + q1 := &dns.Msg{} + q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Second query: serve from cache after upstream failure + q2 := &dns.Msg{} + q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + +// Verifies that cache normalization works across casing and trailing dot variations. +func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString("ExAmPlE.CoM") + require.NoError(t, err) + entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}} + forwarder.UpdateDomains(entries) + + ip := netip.MustParseAddr("9.8.7.6") + + // Initial resolution with mixed case to populate cache + mixedQuery := "ExAmPlE.CoM" + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))). + Return([]netip.Addr{ip}, nil).Once() + + q1 := &dns.Msg{} + q1.SetQuestion(mixedQuery+".", dns.TypeA) + w1 := &test.MockResponseWriter{} + resp1 := forwarder.handleDNSQuery(w1, q1) + require.NotNil(t, resp1) + require.Equal(t, dns.RcodeSuccess, resp1.Rcode) + require.Len(t, resp1.Answer, 1) + + // Subsequent query without dot and upper case should hit cache even if upstream fails + // Forwarder lowercases and uses the question name as-is (no trailing dot here) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")). + Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once() + + q2 := &dns.Msg{} + q2.SetQuestion("EXAMPLE.COM", dns.TypeA) + var writtenResp *dns.Msg + w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} + _ = forwarder.handleDNSQuery(w2, q2) + + require.NotNil(t, writtenResp) + require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) + require.Len(t, writtenResp.Answer, 1) + + mockResolver.AssertExpectations(t) +} + func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { // Test complex overlapping pattern scenarios mockFirewall := &MockFirewall{} From 9021bb512bf7f88157a4c06aa1d9fba6ac80587f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 8 Oct 2025 17:14:24 +0200 Subject: [PATCH 07/29] [client] Recreate agent when receive new session id (#4564) When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time. In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started. --- client/internal/peer/conn.go | 4 +- client/internal/peer/conn_test.go | 16 +++--- client/internal/peer/guard/env.go | 20 ++++++++ client/internal/peer/guard/ice_monitor.go | 16 ++++-- client/internal/peer/guard/sr_watcher.go | 2 +- client/internal/peer/handshaker.go | 51 ++++++++++++------- client/internal/peer/handshaker_listener.go | 8 +-- .../internal/peer/handshaker_listener_test.go | 2 +- client/internal/peer/worker_ice.go | 47 +++++++++-------- 9 files changed, 107 insertions(+), 59 deletions(-) create mode 100644 client/internal/peer/guard/env.go diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ded9aa479..68afe986a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) - conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) + conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) if !isForceRelayed() { - conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) + conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index c839ab147..6b47f95eb 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteOffer(OfferAnswer{ @@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") @@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) { return } - onNewOffeChan := make(chan struct{}) + onNewOfferChan := make(chan struct{}) - conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) { - onNewOffeChan <- struct{}{} + conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { + onNewOfferChan <- struct{}{} }) conn.OnRemoteAnswer(OfferAnswer{ @@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { defer cancel() select { - case <-onNewOffeChan: + case <-onNewOfferChan: // success case <-ctx.Done(): t.Error("expected to receive a new offer notification, but timed out") diff --git a/client/internal/peer/guard/env.go b/client/internal/peer/guard/env.go new file mode 100644 index 000000000..1ea2d21be --- /dev/null +++ b/client/internal/peer/guard/env.go @@ -0,0 +1,20 @@ +package guard + +import ( + "os" + "strconv" + "time" +) + +const ( + envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD" +) + +func GetICEMonitorPeriod() time.Duration { + if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" { + if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return defaultCandidatesMonitorPeriod +} diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 09cf9ae63..0f22ee7b0 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -16,8 +16,8 @@ import ( ) const ( - candidatesMonitorPeriod = 5 * time.Minute - candidateGatheringTimeout = 5 * time.Second + defaultCandidatesMonitorPeriod = 5 * time.Minute + candidateGatheringTimeout = 5 * time.Second ) type ICEMonitor struct { @@ -25,16 +25,19 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config + tickerPeriod time.Duration currentCandidatesAddress []string candidatesMu sync.Mutex } -func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { +func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor { + log.Debugf("prepare ICE monitor with period: %s", period) cm := &ICEMonitor{ ReconnectCh: make(chan struct{}, 1), iFaceDiscover: iFaceDiscover, iceConfig: config, + tickerPeriod: period, } return cm } @@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) { return } - ticker := time.NewTicker(candidatesMonitorPeriod) + // Initial check to populate the candidates for later comparison + if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil { + log.Warnf("Failed to check initial ICE candidates: %v", err) + } + + ticker := time.NewTicker(cm.tickerPeriod) defer ticker.Stop() for { diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 90e45426f..686430752 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -51,7 +51,7 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig) + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) go iceMonitor.Start(ctx, w.onICEChanged) w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 42eaea683..aff26f847 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -44,13 +44,19 @@ type OfferAnswer struct { } type Handshaker struct { - mu sync.Mutex - log *log.Entry - config ConnConfig - signaler *Signaler - ice *WorkerICE - relay *WorkerRelay - onNewOfferListeners []*OfferListener + mu sync.Mutex + log *log.Entry + config ConnConfig + signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + // relayListener is not blocking because the listener is using a goroutine to process the messages + // and it will only keep the latest message if multiple offers are received in a short time + // this is to avoid blocking the handshaker if the listener is doing some heavy processing + // and also to avoid processing old offers if multiple offers are received in a short time + // the listener will always process the latest offer + relayListener *AsyncOfferListener + iceListener func(remoteOfferAnswer *OfferAnswer) // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer @@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W } } -func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { - l := NewOfferListener(offer) - h.onNewOfferListeners = append(h.onNewOfferListeners, l) +func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.relayListener = NewAsyncOfferListener(offer) +} + +func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.iceListener = offer } func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - // received confirmation from the remote peer -> ready to proceed + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) + } + if err := h.sendAnswer(); err != nil { h.log.Errorf("failed to send remote offer confirmation: %s", err) continue } - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) - } - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) case remoteOfferAnswer := <-h.remoteAnswerCh: h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) - for _, listener := range h.onNewOfferListeners { - listener.Notify(&remoteOfferAnswer) + if h.relayListener != nil { + h.relayListener.Notify(&remoteOfferAnswer) + } + + if h.iceListener != nil { + h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): h.log.Infof("stop listening for remote offers and answers") diff --git a/client/internal/peer/handshaker_listener.go b/client/internal/peer/handshaker_listener.go index e2d3f3f38..772e2777f 100644 --- a/client/internal/peer/handshaker_listener.go +++ b/client/internal/peer/handshaker_listener.go @@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string { return oa.SessionID.String() } -type OfferListener struct { +type AsyncOfferListener struct { fn callbackFunc running bool latest *OfferAnswer mu sync.Mutex } -func NewOfferListener(fn callbackFunc) *OfferListener { - return &OfferListener{ +func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener { + return &AsyncOfferListener{ fn: fn, } } -func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) { +func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) { o.mu.Lock() defer o.mu.Unlock() diff --git a/client/internal/peer/handshaker_listener_test.go b/client/internal/peer/handshaker_listener_test.go index 8363741a5..1a7156d10 100644 --- a/client/internal/peer/handshaker_listener_test.go +++ b/client/internal/peer/handshaker_listener_test.go @@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) { runChan <- struct{}{} } - hl := NewOfferListener(longRunningFn) + hl := NewAsyncOfferListener(longRunningFn) hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index eb886a4d3..3675f0157 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn * func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.muxAgent.Lock() + defer w.muxAgent.Unlock() - if w.agentConnecting { - w.log.Debugf("agent connection is in progress, skipping the offer") - w.muxAgent.Unlock() - return - } - - if w.agent != nil { + if w.agent != nil || w.agentConnecting { // backward compatibility with old clients that do not send session ID if remoteOfferAnswer.SessionID == nil { w.log.Debugf("agent already exists, skipping the offer") - w.muxAgent.Unlock() return } if w.remoteSessionID == *remoteOfferAnswer.SessionID { w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) - w.muxAgent.Unlock() return } w.log.Debugf("agent already exists, recreate the connection") @@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { if err := w.agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } + + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil } @@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { preferredCandidateTypes = icemaker.CandidateTypes() } - w.log.Debugf("recreate ICE agent") + if remoteOfferAnswer.SessionID != nil { + w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID) + } dialerCtx, dialerCancel := context.WithCancel(w.ctx) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) if err != nil { w.log.Errorf("failed to recreate ICE Agent: %s", err) - w.muxAgent.Unlock() return } w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true - w.muxAgent.Unlock() + if remoteOfferAnswer.SessionID != nil { + w.remoteSessionID = *remoteOfferAnswer.SessionID + } else { + w.remoteSessionID = "" + } go w.connect(dialerCtx, agent, remoteOfferAnswer) } @@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.muxAgent.Lock() w.agentConnecting = false w.lastSuccess = time.Now() - if remoteOfferAnswer.SessionID != nil { - w.remoteSessionID = *remoteOfferAnswer.SessionID - } w.muxAgent.Unlock() // todo: the potential problem is a race between the onConnectionStateChange @@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C } w.muxAgent.Lock() - // todo review does it make sense to generate new session ID all the time when w.agent==agent - sessionID, err := NewICESessionID() - if err != nil { - w.log.Errorf("failed to create new session ID: %s", err) - } - w.sessionID = sessionID if w.agent == agent { + // consider to remove from here and move to the OnNewOffer + sessionID, err := NewICESessionID() + if err != nil { + w.log.Errorf("failed to create new session ID: %s", err) + } + w.sessionID = sessionID w.agent = nil w.agentConnecting = false + w.remoteSessionID = "" } w.muxAgent.Unlock() } @@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // notify the conn.onICEStateDisconnected changes to update the current used priority + w.closeAgent(agent, dialerCancel) + if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected w.conn.onICEStateDisconnected() } - w.closeAgent(agent, dialerCancel) default: return } From 654aa9581d15a7e2b2332f7f6ee395eeac40449c Mon Sep 17 00:00:00 2001 From: Ashley Date: Wed, 8 Oct 2025 20:27:32 +0100 Subject: [PATCH 08/29] [client,gui] Update url_windows.go to offer arm64 executable download (#4586) --- version/url_windows.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/version/url_windows.go b/version/url_windows.go index f2055b109..14fdb7ae6 100644 --- a/version/url_windows.go +++ b/version/url_windows.go @@ -1,9 +1,13 @@ package version -import "golang.org/x/sys/windows/registry" +import ( + "golang.org/x/sys/windows/registry" + "runtime" +) const ( urlWinExe = "https://pkgs.netbird.io/windows/x64" + urlWinExeArm = "https://pkgs.netbird.io/windows/arm64" ) var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird" @@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne // DownloadUrl return with the proper download link func DownloadUrl() string { _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE) - if err == nil { - return urlWinExe - } else { + if err != nil { return downloadURL } + + url := urlWinExe + if runtime.GOARCH == "arm64" { + url = urlWinExeArm + } + + return url } From 4e03f708a4cf9461f149b0227cbefae9abbad29e Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 9 Oct 2025 17:39:02 +0300 Subject: [PATCH 09/29] fix dns forwarder port update (#4613) fix dns forwarder port update (#4613) --- client/internal/dnsfwd/manager.go | 16 +++++++--------- client/internal/engine.go | 8 ++++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 5c7a3fbdd..a3a4ba40f 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -40,7 +40,6 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder - port uint16 } func ListenPort() uint16 { @@ -49,11 +48,16 @@ func ListenPort() uint16 { return listenPort } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { +func SetListenPort(port uint16) { + listenPortMu.Lock() + listenPort = port + listenPortMu.Unlock() +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, - port: port, } } @@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - if m.port > 0 { - listenPortMu.Lock() - listenPort = m.port - listenPortMu.Unlock() - } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 646e059d4..bebf04f6c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder( return } + if forwarderPort > 0 { + dnsfwd.SetListenPort(forwarderPort) + } + if !enabled { if e.dnsForwardMgr == nil { return @@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder( if len(fwdEntries) > 0 { switch { case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil @@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { log.Errorf("failed to stop DNS forward: %v", err) } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil From d35a845dbd336206f2b306384f1faca15cd4bd03 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 9 Oct 2025 22:18:00 +0300 Subject: [PATCH 10/29] [management] sync all other peers on peer add/remove (#4614) --- management/server/peer.go | 27 ++------------------------- management/server/peer_test.go | 4 ++-- 2 files changed, 4 insertions(+), 27 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 4cf5d1e46..218edf299 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } var peer *nbpeer.Peer - var updateAccountPeers bool var eventsToStore []func() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID) - if err != nil { - return err - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers && userID != activity.SystemInitiator { + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err) } - updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID) - if err != nil { - updateAccountPeers = true - } - if newPeer == nil { return nil, nil, nil, fmt.Errorf("new peer is nil") } @@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if updateAccountPeers { - am.BufferUpdateAccountPeers(ctx, accountID) - } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } @@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { - peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) - if err != nil { - return false, err - } - return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction -} - // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 42b3244ae..fd795b926 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) // + peerShouldReceiveUpdate(t, updMsg) // close(done) }() @@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() From bedd3cabc994d6bb9903a95fc6753331c987c25c Mon Sep 17 00:00:00 2001 From: Kostya Leschenko Date: Fri, 10 Oct 2025 16:24:24 +0300 Subject: [PATCH 11/29] [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) --- client/internal/dns/systemd_linux.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index 0e8a53a63..d9854c033 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -31,6 +31,7 @@ const ( systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute" systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains" systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC" + systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS" systemdDbusResolvConfModeForeign = "foreign" dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject" @@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana log.Warnf("failed to set DNSSEC to 'no': %v", err) } + // We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off + if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil { + log.Warnf("failed to set DNSOverTLS to 'no': %v", err) + } + var ( searchDomains []string matchDomains []string From 5151f19d29514436f65fe40faebf4c89da06133d Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 10 Oct 2025 16:15:51 +0200 Subject: [PATCH 12/29] [management] pass temporary flag to validator (#4599) --- go.mod | 2 +- go.sum | 4 ++-- management/server/integrated_validator.go | 2 +- .../server/integrations/integrated_validator/interface.go | 4 ++-- management/server/peer.go | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index a1560b409..31b45e881 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 13838b82d..6b0b298a7 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 21f11bfce..251c04273 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } -func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer { +func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index ce632d567..be05c2527 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -3,16 +3,16 @@ package integrated_validator import ( "context" - "github.com/netbirdio/netbird/shared/management/proto" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) - PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer + PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error diff --git a/management/server/peer.go b/management/server/peer.go index 218edf299..30b7073ef 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -578,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { From 0d2e67983ad2de4fc301514deff3409879819520 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 10 Oct 2025 14:16:48 -0300 Subject: [PATCH 13/29] [misc] Add service definition for netbird-signal (#4620) --- infrastructure_files/docker-compose.yml.tmpl.traefik | 1 + 1 file changed, 1 insertion(+) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 196e26a66..0010974c5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -49,6 +49,7 @@ services: - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) + - traefik.http.routers.netbird-signal.service=netbird-signal - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c From 000e99e7f3f5c18741ee3412ba113c32c6ce3aa7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 13 Oct 2025 17:50:16 +0200 Subject: [PATCH 14/29] [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) --- client/wasm/internal/rdp/cert_validation.go | 15 ++- client/wasm/internal/rdp/rdcleanpath.go | 93 ++++++++++++-- .../wasm/internal/rdp/rdcleanpath_handlers.go | 119 +++++++++--------- 3 files changed, 152 insertions(+), 75 deletions(-) diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go index 4a23a4bc8..1678c3996 100644 --- a/client/wasm/internal/rdp/cert_validation.go +++ b/client/wasm/internal/rdp/cert_validation.go @@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert } } -func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { - return &tls.Config{ +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config { + config := &tls.Config{ InsecureSkipVerify: true, // We'll validate manually after handshake VerifyConnection: func(cs tls.ConnectionState) error { var certChain [][]byte @@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl return nil }, } + + // CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3 + if requiresCredSSP { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS12 + } else { + config.MinVersion = tls.VersionTLS12 + config.MaxVersion = tls.VersionTLS13 + } + + return config } diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go index 8062a05cc..16bf63bb9 100644 --- a/client/wasm/internal/rdp/rdcleanpath.go +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -6,11 +6,13 @@ import ( "context" "crypto/tls" "encoding/asn1" + "errors" "fmt" "io" "net" "sync" "syscall/js" + "time" log "github.com/sirupsen/logrus" ) @@ -19,18 +21,34 @@ const ( RDCleanPathVersion = 3390 RDCleanPathProxyHost = "rdcleanpath.proxy.local" RDCleanPathProxyScheme = "ws" + + rdpDialTimeout = 15 * time.Second + + GeneralErrorCode = 1 + WSAETimedOut = 10060 + WSAEConnRefused = 10061 + WSAEConnAborted = 10053 + WSAEConnReset = 10054 + WSAEGenericError = 10050 ) type RDCleanPathPDU struct { - Version int64 `asn1:"tag:0,explicit"` - Error []byte `asn1:"tag:1,explicit,optional"` - Destination string `asn1:"utf8,tag:2,explicit,optional"` - ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` - ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` - PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` - X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` - ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` - ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` + Version int64 `asn1:"tag:0,explicit"` + Error RDCleanPathErr `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathErr struct { + ErrorCode int16 `asn1:"tag:0,explicit"` + HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"` + WSALastError int16 `asn1:"tag:2,explicit,optional"` + TLSAlertCode int8 `asn1:"tag:3,explicit,optional"` } type RDCleanPathProxy struct { @@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] destination := conn.destination log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } conn.rdpConn = rdpConn @@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] _, err = rdpConn.Write(firstPacket) if err != nil { log.Errorf("Failed to write first packet: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket [] n, err := rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { conn.wsHandlers.Call("send", uint8Array.Get("buffer")) } } + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + p.sendToWebSocket(conn, data) +} + +func errorToWSACode(err error) int16 { + if err == nil { + return WSAEGenericError + } + var netErr *net.OpError + if errors.As(err, &netErr) && netErr.Timeout() { + return WSAETimedOut + } + if errors.Is(err, context.DeadlineExceeded) { + return WSAETimedOut + } + if errors.Is(err, context.Canceled) { + return WSAEConnAborted + } + if errors.Is(err, io.EOF) { + return WSAEConnReset + } + return WSAEGenericError +} + +func newWSAError(err error) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + WSALastError: errorToWSACode(err), + }, + } +} + +func newHTTPError(statusCode int16) RDCleanPathPDU { + return RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: RDCleanPathErr{ + ErrorCode: GeneralErrorCode, + HTTPStatusCode: statusCode, + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go index 010efa5ea..97bb46338 100644 --- a/client/wasm/internal/rdp/rdcleanpath_handlers.go +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -3,6 +3,7 @@ package rdp import ( + "context" "crypto/tls" "encoding/asn1" "io" @@ -11,11 +12,17 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP) + protocolSSL = 0x00000001 + protocolHybridEx = 0x00000008 +) + func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) if pdu.Version != RDCleanPathVersion { - p.sendRDCleanPathError(conn, "Unsupported version") + p.sendRDCleanPathError(conn, newHTTPError(400)) return } @@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl destination = pdu.Destination } - rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout) + defer cancel() + + rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination) if err != nil { log.Errorf("Failed to connect to %s: %v", destination, err) - p.sendRDCleanPathError(conn, "Connection failed") + p.sendRDCleanPathError(conn, newWSAError(err)) p.cleanupConnection(conn) return } @@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl p.setupTLSConnection(conn, pdu) } +// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required. +// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags. +// Returns (requiresTLS12, selectedProtocol, detectionSuccessful). +func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) { + const minResponseLength = 19 + + if len(x224Response) < minResponseLength { + return false, 0, false + } + + // Per X.224 specification: + // x224Response[0] == 0x03: Length of X.224 header (3 bytes) + // x224Response[5] == 0xD0: X.224 Data TPDU code + if x224Response[0] != 0x03 || x224Response[5] != 0xD0 { + return false, 0, false + } + + if x224Response[11] == 0x02 { + flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 | + uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24 + + hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0 + return hasNLA, flags, true + } + + return false, 0, false +} + func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { var x224Response []byte if len(pdu.X224ConnectionPDU) > 0 { @@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) if err != nil { log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean n, err := conn.rdpConn.Read(response) if err != nil { log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") + p.sendRDCleanPathError(conn, newWSAError(err)) return } x224Response = response[:n] log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) } - tlsConfig := p.getTLSConfigWithValidation(conn) + requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response) + if detected { + if requiresCredSSP { + log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol) + } else { + log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol) + } + } else { + log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3") + } + + tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP) tlsConn := tls.Client(conn.rdpConn, tlsConfig) conn.tlsConn = tlsConn if err := tlsConn.Handshake(); err != nil { log.Errorf("TLS handshake failed: %v", err) - p.sendRDCleanPathError(conn, "TLS handshake failed") + p.sendRDCleanPathError(conn, newWSAError(err)) return } @@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean p.cleanupConnection(conn) } -func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { - if len(pdu.X224ConnectionPDU) > 0 { - log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) - _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) - if err != nil { - log.Errorf("Failed to write X.224 PDU: %v", err) - p.sendRDCleanPathError(conn, "Failed to forward X.224") - return - } - - response := make([]byte, 1024) - n, err := conn.rdpConn.Read(response) - if err != nil { - log.Errorf("Failed to read X.224 response: %v", err) - p.sendRDCleanPathError(conn, "Failed to read X.224 response") - return - } - - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - X224ConnectionPDU: response[:n], - ServerAddr: conn.destination, - } - - p.sendRDCleanPathPDU(conn, responsePDU) - } else { - responsePDU := RDCleanPathPDU{ - Version: RDCleanPathVersion, - ServerAddr: conn.destination, - } - p.sendRDCleanPathPDU(conn, responsePDU) - } - - go p.forwardConnToWS(conn, conn.rdpConn, "TCP") - go p.forwardWSToConn(conn, conn.rdpConn, "TCP") - - <-conn.ctx.Done() - log.Debug("TCP connection context done, cleaning up") - p.cleanupConnection(conn) -} - func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { data, err := asn1.Marshal(pdu) if err != nil { @@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean p.sendToWebSocket(conn, data) } -func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { - pdu := RDCleanPathPDU{ - Version: RDCleanPathVersion, - Error: []byte(errorMsg), - } - - data, err := asn1.Marshal(pdu) - if err != nil { - log.Errorf("Failed to marshal error PDU: %v", err) - return - } - - p.sendToWebSocket(conn, data) -} - func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { msgChan := make(chan []byte) errChan := make(chan error) From bb37dc89cef3d8338277503ff5e0bbfecb76ffca Mon Sep 17 00:00:00 2001 From: John Conley <8932043+jfrconley@users.noreply.github.com> Date: Thu, 16 Oct 2025 01:46:29 -0700 Subject: [PATCH 15/29] [management] feat: Basic PocketID IDP integration (#4529) --- management/server/idp/auth0_test.go | 8 +- management/server/idp/idp.go | 6 + management/server/idp/pocketid.go | 384 +++++++++++++++++++++++++ management/server/idp/pocketid_test.go | 138 +++++++++ 4 files changed, 533 insertions(+), 3 deletions(-) create mode 100644 management/server/idp/pocketid.go create mode 100644 management/server/idp/pocketid_test.go diff --git a/management/server/idp/auth0_test.go b/management/server/idp/auth0_test.go index 66c16870b..bc352f117 100644 --- a/management/server/idp/auth0_test.go +++ b/management/server/idp/auth0_test.go @@ -26,9 +26,11 @@ type mockHTTPClient struct { } func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { - body, err := io.ReadAll(req.Body) - if err == nil { - c.reqBody = string(body) + if req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + c.reqBody = string(body) + } } return &http.Response{ StatusCode: c.code, diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 51f99b3b7..f06e57196 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr APIToken: config.ExtraConfig["ApiToken"], } return NewJumpCloudManager(jumpcloudConfig, appMetrics) + case "pocketid": + pocketidConfig := PocketIdClientConfig{ + APIToken: config.ExtraConfig["ApiToken"], + ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"], + } + return NewPocketIdManager(pocketidConfig, appMetrics) default: return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType) } diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go new file mode 100644 index 000000000..38a5cc67f --- /dev/null +++ b/management/server/idp/pocketid.go @@ -0,0 +1,384 @@ +package idp + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type PocketIdManager struct { + managementEndpoint string + apiToken string + httpClient ManagerHTTPClient + credentials ManagerCredentials + helper ManagerHelper + appMetrics telemetry.AppMetrics +} + +type pocketIdCustomClaimDto struct { + Key string `json:"key"` + Value string `json:"value"` +} + +type pocketIdUserDto struct { + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + Disabled bool `json:"disabled"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + ID string `json:"id"` + IsAdmin bool `json:"isAdmin"` + LastName string `json:"lastName"` + LdapID string `json:"ldapId"` + Locale string `json:"locale"` + UserGroups []pocketIdUserGroupDto `json:"userGroups"` + Username string `json:"username"` +} + +type pocketIdUserCreateDto struct { + Disabled bool `json:"disabled,omitempty"` + DisplayName string `json:"displayName"` + Email string `json:"email"` + FirstName string `json:"firstName"` + IsAdmin bool `json:"isAdmin,omitempty"` + LastName string `json:"lastName,omitempty"` + Locale string `json:"locale,omitempty"` + Username string `json:"username"` +} + +type pocketIdPaginatedUserDto struct { + Data []pocketIdUserDto `json:"data"` + Pagination pocketIdPaginationDto `json:"pagination"` +} + +type pocketIdPaginationDto struct { + CurrentPage int `json:"currentPage"` + ItemsPerPage int `json:"itemsPerPage"` + TotalItems int `json:"totalItems"` + TotalPages int `json:"totalPages"` +} + +func (p *pocketIdUserDto) userData() *UserData { + return &UserData{ + Email: p.Email, + Name: p.DisplayName, + ID: p.ID, + AppMetadata: AppMetadata{}, + } +} + +type pocketIdUserGroupDto struct { + CreatedAt string `json:"createdAt"` + CustomClaims []pocketIdCustomClaimDto `json:"customClaims"` + FriendlyName string `json:"friendlyName"` + ID string `json:"id"` + LdapID string `json:"ldapId"` + Name string `json:"name"` +} + +func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) { + httpTransport := http.DefaultTransport.(*http.Transport).Clone() + httpTransport.MaxIdleConns = 5 + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: httpTransport, + } + helper := JsonParser{} + + if config.ManagementEndpoint == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing") + } + + if config.APIToken == "" { + return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing") + } + + credentials := &PocketIdCredentials{ + clientConfig: config, + httpClient: httpClient, + helper: helper, + appMetrics: appMetrics, + } + + return &PocketIdManager{ + managementEndpoint: config.ManagementEndpoint, + apiToken: config.APIToken, + httpClient: httpClient, + credentials: credentials, + helper: helper, + appMetrics: appMetrics, + }, nil +} + +func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) { + var MethodsWithBody = []string{http.MethodPost, http.MethodPut} + if !slices.Contains(MethodsWithBody, method) && body != "" { + return nil, fmt.Errorf("Body provided to unsupported method: %s", method) + } + + reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource) + if query != nil { + reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode()) + } + var req *http.Request + var err error + if body != "" { + req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body)) + } else { + req, err = http.NewRequestWithContext(ctx, method, reqURL, nil) + } + if err != nil { + return nil, err + } + + req.Header.Add("X-API-KEY", p.apiToken) + + if body != "" { + req.Header.Add("content-type", "application/json") + req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength)) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestError() + } + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountRequestStatusError() + } + + return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode) + } + + return io.ReadAll(resp.Body) +} + +// getAllUsersPaginated fetches all users from PocketID API using pagination +func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) { + var allUsers []pocketIdUserDto + currentPage := 1 + + for { + params := url.Values{} + // Copy existing search parameters + for key, values := range searchParams { + params[key] = values + } + + params.Set("pagination[limit]", "100") + params.Set("pagination[page]", fmt.Sprintf("%d", currentPage)) + + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + var profiles pocketIdPaginatedUserDto + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + allUsers = append(allUsers, profiles.Data...) + + // Check if we've reached the last page + if currentPage >= profiles.Pagination.TotalPages { + break + } + + currentPage++ + } + + return allUsers, nil +} + +func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error { + return nil +} + +func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) { + body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserDataByID() + } + + var user pocketIdUserDto + err = p.helper.Unmarshal(body, &user) + if err != nil { + return nil, err + } + + userData := user.userData() + userData.AppMetadata = appMetadata + + return userData, nil +} + +func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAccount() + } + + users := make([]*UserData, 0) + for _, profile := range allUsers { + userData := profile.userData() + userData.AppMetadata.WTAccountID = accountId + + users = append(users, userData) + } + return users, nil +} + +func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) { + // Get all users using pagination + allUsers, err := p.getAllUsersPaginated(ctx, url.Values{}) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetAllAccounts() + } + + indexedUsers := make(map[string][]*UserData) + for _, profile := range allUsers { + userData := profile.userData() + indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData) + } + + return indexedUsers, nil +} + +func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + firstLast := strings.Split(name, " ") + + createUser := pocketIdUserCreateDto{ + Disabled: false, + DisplayName: name, + Email: email, + FirstName: firstLast[0], + LastName: firstLast[1], + Username: firstLast[0] + "." + firstLast[1], + } + payload, err := p.helper.Marshal(createUser) + if err != nil { + return nil, err + } + + body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload)) + if err != nil { + return nil, err + } + var newUser pocketIdUserDto + err = p.helper.Unmarshal(body, &newUser) + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountCreateUser() + } + var pending bool = true + ret := &UserData{ + Email: email, + Name: name, + ID: newUser.ID, + AppMetadata: AppMetadata{ + WTAccountID: accountID, + WTPendingInvite: &pending, + WTInvitedBy: invitedByEmail, + }, + } + return ret, nil +} + +func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) { + params := url.Values{ + // This value a + "search": []string{email}, + } + body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "") + if err != nil { + return nil, err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountGetUserByEmail() + } + + var profiles struct{ data []pocketIdUserDto } + err = p.helper.Unmarshal(body, &profiles) + if err != nil { + return nil, err + } + + users := make([]*UserData, 0) + for _, profile := range profiles.data { + users = append(users, profile.userData()) + } + return users, nil +} + +func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "") + if err != nil { + return err + } + return nil +} + +func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error { + _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "") + if err != nil { + return err + } + + if p.appMetrics != nil { + p.appMetrics.IDPMetrics().CountDeleteUser() + } + + return nil +} + +var _ Manager = (*PocketIdManager)(nil) + +type PocketIdClientConfig struct { + APIToken string + ManagementEndpoint string +} + +type PocketIdCredentials struct { + clientConfig PocketIdClientConfig + helper ManagerHelper + httpClient ManagerHTTPClient + appMetrics telemetry.AppMetrics +} + +var _ ManagerCredentials = (*PocketIdCredentials)(nil) + +func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) { + return JWTToken{}, nil +} diff --git a/management/server/idp/pocketid_test.go b/management/server/idp/pocketid_test.go new file mode 100644 index 000000000..49075a0d3 --- /dev/null +++ b/management/server/idp/pocketid_test.go @@ -0,0 +1,138 @@ +package idp + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/telemetry" +) + + +func TestNewPocketIdManager(t *testing.T) { + type test struct { + name string + inputConfig PocketIdClientConfig + assertErrFunc require.ErrorAssertionFunc + assertErrFuncMessage string + } + + defaultTestConfig := PocketIdClientConfig{ + APIToken: "api_token", + ManagementEndpoint: "http://localhost", + } + + tests := []test{ + { + name: "Good Configuration", + inputConfig: defaultTestConfig, + assertErrFunc: require.NoError, + assertErrFuncMessage: "shouldn't return error", + }, + { + name: "Missing ManagementEndpoint", + inputConfig: PocketIdClientConfig{ + APIToken: defaultTestConfig.APIToken, + ManagementEndpoint: "", + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + { + name: "Missing APIToken", + inputConfig: PocketIdClientConfig{ + APIToken: "", + ManagementEndpoint: defaultTestConfig.ManagementEndpoint, + }, + assertErrFunc: require.Error, + assertErrFuncMessage: "should return error when field empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{}) + tc.assertErrFunc(t, err, tc.assertErrFuncMessage) + }) + } +} + +func TestPocketID_GetUserDataByID(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + md := AppMetadata{WTAccountID: "acc1"} + got, err := mgr.GetUserDataByID(context.Background(), "u1", md) + require.NoError(t, err) + assert.Equal(t, "u1", got.ID) + assert.Equal(t, "user1@example.com", got.Email) + assert.Equal(t, "User One", got.Name) + assert.Equal(t, "acc1", got.AppMetadata.WTAccountID) +} + +func TestPocketID_GetAccount_WithPagination(t *testing.T) { + // Single page response with two users + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + users, err := mgr.GetAccount(context.Background(), "accX") + require.NoError(t, err) + require.Len(t, users, 2) + assert.Equal(t, "u1", users[0].ID) + assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID) + assert.Equal(t, "u2", users[1].ID) +} + +func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) { + client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + accounts, err := mgr.GetAllAccounts(context.Background()) + require.NoError(t, err) + require.Len(t, accounts[UnsetAccountID], 2) +} + +func TestPocketID_CreateUser(t *testing.T) { + client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com") + require.NoError(t, err) + assert.Equal(t, "newid", ud.ID) + assert.Equal(t, "new@example.com", ud.Email) + assert.Equal(t, "New User", ud.Name) + assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID) + if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) { + assert.True(t, *ud.AppMetadata.WTPendingInvite) + } + assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy) +} + +func TestPocketID_InviteAndDeleteUser(t *testing.T) { + // Same mock for both calls; returns OK with empty JSON + client := &mockHTTPClient{code: 200, resBody: `{}`} + + mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil) + require.NoError(t, err) + mgr.httpClient = client + + err = mgr.InviteUserByID(context.Background(), "u1") + require.NoError(t, err) + + err = mgr.DeleteUser(context.Background(), "u1") + require.NoError(t, err) +} From 277aa2b7cc82368d6ec0c6f73acfdfa676ad2c9b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:13:41 +0200 Subject: [PATCH 16/29] [client] Fix missing flag values in profiles (#4650) --- client/server/server.go | 7 + client/server/setconfig_test.go | 298 ++++++++++++++++++++++++++++++++ 2 files changed, 305 insertions(+) create mode 100644 client/server/setconfig_test.go diff --git a/client/server/server.go b/client/server/server.go index e6de608c5..89f50a1ef 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques config.CustomDNSAddress = []byte{} } + config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + + if msg.DnsRouteInterval != nil { + interval := msg.DnsRouteInterval.AsDuration() + config.DNSRouteInterval = &interval + } + config.RosenpassEnabled = msg.RosenpassEnabled config.RosenpassPermissive = msg.RosenpassPermissive config.DisableAutoConnect = msg.DisableAutoConnect diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go new file mode 100644 index 000000000..1260bcc78 --- /dev/null +++ b/client/server/setconfig_test.go @@ -0,0 +1,298 @@ +package server + +import ( + "context" + "os/user" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/netbirdio/netbird/client/internal/profilemanager" + "github.com/netbirdio/netbird/client/proto" +) + +// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config. +// This test uses reflection to detect when new fields are added but not handled in SetConfig. +func TestSetConfig_AllFieldsSaved(t *testing.T) { + tempDir := t.TempDir() + origDefaultProfileDir := profilemanager.DefaultConfigPathDir + origDefaultConfigPath := profilemanager.DefaultConfigPath + origActiveProfileStatePath := profilemanager.ActiveProfileStatePath + profilemanager.ConfigDirOverride = tempDir + profilemanager.DefaultConfigPathDir = tempDir + profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json" + profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json") + t.Cleanup(func() { + profilemanager.DefaultConfigPathDir = origDefaultProfileDir + profilemanager.ActiveProfileStatePath = origActiveProfileStatePath + profilemanager.DefaultConfigPath = origDefaultConfigPath + profilemanager.ConfigDirOverride = "" + }) + + currUser, err := user.Current() + require.NoError(t, err) + + profName := "test-profile" + + ic := profilemanager.ConfigInput{ + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: "https://api.netbird.io:443", + } + _, err = profilemanager.UpdateOrCreateConfig(ic) + require.NoError(t, err) + + pm := profilemanager.ServiceManager{} + err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + }) + require.NoError(t, err) + + ctx := context.Background() + s := New(ctx, "console", "", false, false) + + rosenpassEnabled := true + rosenpassPermissive := true + serverSSHAllowed := true + interfaceName := "utun100" + wireguardPort := int64(51820) + preSharedKey := "test-psk" + disableAutoConnect := true + networkMonitor := true + disableClientRoutes := true + disableServerRoutes := true + disableDNS := true + disableFirewall := true + blockLANAccess := true + disableNotifications := true + lazyConnectionEnabled := true + blockInbound := true + mtu := int64(1280) + + req := &proto.SetConfigRequest{ + ProfileName: profName, + Username: currUser.Username, + ManagementUrl: "https://new-api.netbird.io:443", + AdminURL: "https://new-admin.netbird.io", + RosenpassEnabled: &rosenpassEnabled, + RosenpassPermissive: &rosenpassPermissive, + ServerSSHAllowed: &serverSSHAllowed, + InterfaceName: &interfaceName, + WireguardPort: &wireguardPort, + OptionalPreSharedKey: &preSharedKey, + DisableAutoConnect: &disableAutoConnect, + NetworkMonitor: &networkMonitor, + DisableClientRoutes: &disableClientRoutes, + DisableServerRoutes: &disableServerRoutes, + DisableDns: &disableDNS, + DisableFirewall: &disableFirewall, + BlockLanAccess: &blockLANAccess, + DisableNotifications: &disableNotifications, + LazyConnectionEnabled: &lazyConnectionEnabled, + BlockInbound: &blockInbound, + NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"}, + CleanNATExternalIPs: false, + CustomDNSAddress: []byte("1.1.1.1:53"), + ExtraIFaceBlacklist: []string{"eth1", "eth2"}, + DnsLabels: []string{"label1", "label2"}, + CleanDNSLabels: false, + DnsRouteInterval: durationpb.New(2 * time.Minute), + Mtu: &mtu, + } + + _, err = s.SetConfig(ctx, req) + require.NoError(t, err) + + profState := profilemanager.ActiveProfileState{ + Name: profName, + Username: currUser.Username, + } + cfgPath, err := profState.FilePath() + require.NoError(t, err) + + cfg, err := profilemanager.GetConfig(cfgPath) + require.NoError(t, err) + + require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String()) + require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String()) + require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled) + require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive) + require.NotNil(t, cfg.ServerSSHAllowed) + require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed) + require.Equal(t, interfaceName, cfg.WgIface) + require.Equal(t, int(wireguardPort), cfg.WgPort) + require.Equal(t, preSharedKey, cfg.PreSharedKey) + require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect) + require.NotNil(t, cfg.NetworkMonitor) + require.Equal(t, networkMonitor, *cfg.NetworkMonitor) + require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes) + require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes) + require.Equal(t, disableDNS, cfg.DisableDNS) + require.Equal(t, disableFirewall, cfg.DisableFirewall) + require.Equal(t, blockLANAccess, cfg.BlockLANAccess) + require.NotNil(t, cfg.DisableNotifications) + require.Equal(t, disableNotifications, *cfg.DisableNotifications) + require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled) + require.Equal(t, blockInbound, cfg.BlockInbound) + require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs) + require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress) + // IFaceBlackList contains defaults + extras + require.Contains(t, cfg.IFaceBlackList, "eth1") + require.Contains(t, cfg.IFaceBlackList, "eth2") + require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList()) + require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval) + require.Equal(t, uint16(mtu), cfg.MTU) + + verifyAllFieldsCovered(t, req) +} + +// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest. +// If a new field is added to SetConfigRequest, this function will fail the test, +// forcing the developer to update both the SetConfig handler and this test. +func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) { + t.Helper() + + metadataFields := map[string]bool{ + "state": true, // protobuf internal + "sizeCache": true, // protobuf internal + "unknownFields": true, // protobuf internal + "Username": true, // metadata + "ProfileName": true, // metadata + "CleanNATExternalIPs": true, // control flag for clearing + "CleanDNSLabels": true, // control flag for clearing + } + + expectedFields := map[string]bool{ + "ManagementUrl": true, + "AdminURL": true, + "RosenpassEnabled": true, + "RosenpassPermissive": true, + "ServerSSHAllowed": true, + "InterfaceName": true, + "WireguardPort": true, + "OptionalPreSharedKey": true, + "DisableAutoConnect": true, + "NetworkMonitor": true, + "DisableClientRoutes": true, + "DisableServerRoutes": true, + "DisableDns": true, + "DisableFirewall": true, + "BlockLanAccess": true, + "DisableNotifications": true, + "LazyConnectionEnabled": true, + "BlockInbound": true, + "NatExternalIPs": true, + "CustomDNSAddress": true, + "ExtraIFaceBlacklist": true, + "DnsLabels": true, + "DnsRouteInterval": true, + "Mtu": true, + } + + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unexpectedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + if metadataFields[fieldName] { + continue + } + + if !expectedFields[fieldName] { + unexpectedFields = append(unexpectedFields, fieldName) + } + } + + if len(unexpectedFields) > 0 { + t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields) + } +} + +// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest. +// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq. +func TestCLIFlags_MappedToSetConfig(t *testing.T) { + // Map of CLI flag names to their corresponding SetConfigRequest field names. + // This map must be updated when adding new config-related CLI flags. + flagToField := map[string]string{ + "management-url": "ManagementUrl", + "admin-url": "AdminURL", + "enable-rosenpass": "RosenpassEnabled", + "rosenpass-permissive": "RosenpassPermissive", + "allow-server-ssh": "ServerSSHAllowed", + "interface-name": "InterfaceName", + "wireguard-port": "WireguardPort", + "preshared-key": "OptionalPreSharedKey", + "disable-auto-connect": "DisableAutoConnect", + "network-monitor": "NetworkMonitor", + "disable-client-routes": "DisableClientRoutes", + "disable-server-routes": "DisableServerRoutes", + "disable-dns": "DisableDns", + "disable-firewall": "DisableFirewall", + "block-lan-access": "BlockLanAccess", + "block-inbound": "BlockInbound", + "enable-lazy-connection": "LazyConnectionEnabled", + "external-ip-map": "NatExternalIPs", + "dns-resolver-address": "CustomDNSAddress", + "extra-iface-blacklist": "ExtraIFaceBlacklist", + "extra-dns-labels": "DnsLabels", + "dns-router-interval": "DnsRouteInterval", + "mtu": "Mtu", + } + + // SetConfigRequest fields that don't have CLI flags (settable only via UI or other means). + fieldsWithoutCLIFlags := map[string]bool{ + "DisableNotifications": true, // Only settable via UI + } + + // Get all SetConfigRequest fields to verify our map is complete. + req := &proto.SetConfigRequest{} + val := reflect.ValueOf(req).Elem() + typ := val.Type() + + var unmappedFields []string + for i := 0; i < val.NumField(); i++ { + field := typ.Field(i) + fieldName := field.Name + + // Skip protobuf internal fields and metadata fields. + if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" { + continue + } + if fieldName == "Username" || fieldName == "ProfileName" { + continue + } + if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" { + continue + } + + // Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag. + mappedToCLI := false + for _, mappedField := range flagToField { + if mappedField == fieldName { + mappedToCLI = true + break + } + } + + hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName] + + if !mappedToCLI && !hasNoCLIFlag { + unmappedFields = append(unmappedFields, fieldName) + } + } + + if len(unmappedFields) > 0 { + t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+ + "Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+ + "add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields) + } + + t.Log("All SetConfigRequest fields are properly documented") +} From 8252ff41db5fed928a1666e84061ee30d19addf9 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:58:29 +0200 Subject: [PATCH 17/29] [client] Add bind activity listener to bypass udp sockets (#4646) --- client/iface/device.go | 1 + client/iface/device/device_android.go | 5 + client/iface/device/device_darwin.go | 5 + client/iface/device/device_ios.go | 5 + client/iface/device/device_kernel_unix.go | 5 + client/iface/device/device_netstack.go | 6 + client/iface/device/device_usp_unix.go | 5 + client/iface/device/device_windows.go | 5 + client/iface/device/endpoint_manager.go | 13 + client/iface/device_android.go | 1 + client/iface/iface.go | 11 + .../internal/lazyconn/activity/lazy_conn.go | 82 +++++ .../lazyconn/activity/listener_bind.go | 127 ++++++++ .../lazyconn/activity/listener_bind_test.go | 291 ++++++++++++++++++ .../lazyconn/activity/listener_test.go | 41 --- .../activity/{listener.go => listener_udp.go} | 35 ++- .../lazyconn/activity/listener_udp_test.go | 110 +++++++ client/internal/lazyconn/activity/manager.go | 45 ++- .../lazyconn/activity/manager_test.go | 38 ++- client/internal/lazyconn/wgiface.go | 2 + 20 files changed, 760 insertions(+), 73 deletions(-) create mode 100644 client/iface/device/endpoint_manager.go create mode 100644 client/internal/lazyconn/activity/lazy_conn.go create mode 100644 client/internal/lazyconn/activity/listener_bind.go create mode 100644 client/internal/lazyconn/activity/listener_bind_test.go delete mode 100644 client/internal/lazyconn/activity/listener_test.go rename client/internal/lazyconn/activity/{listener.go => listener_udp.go} (64%) create mode 100644 client/internal/lazyconn/activity/listener_udp_test.go diff --git a/client/iface/device.go b/client/iface/device.go index 921f0ea98..c0c829825 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -23,4 +23,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index a731684cc..48346fc0f 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net { return nil } +// GetICEBind returns the ICEBind instance +func (t *WGTunDevice) GetICEBind() EndpointManager { + return t.iceBind +} + func routesToString(routes []string) string { return strings.Join(routes, ";") } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 390efe088..acd5f6f11 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 96e4c8bcf..f96edf992 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index cdac43a53..2a836f846 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns nil for kernel mode devices +func (t *TunKernelDevice) GetICEBind() EndpointManager { + return nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index e37321b68..40d8fdac8 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -21,6 +21,7 @@ type Bind interface { conn.Bind GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) ActivityRecorder() *bind.ActivityRecorder + EndpointManager } type TunNetstackDevice struct { @@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device { func (t *TunNetstackDevice) GetNet() *netstack.Net { return t.net } + +// GetICEBind returns the bind instance +func (t *TunNetstackDevice) GetICEBind() EndpointManager { + return t.bind +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 4cdd70a32..24654fc03 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error { func (t *USPDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *USPDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index f1023bc0a..96350df8a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error { func (t *TunDevice) GetNet() *netstack.Net { return nil } + +// GetICEBind returns the ICEBind instance +func (t *TunDevice) GetICEBind() EndpointManager { + return t.iceBind +} diff --git a/client/iface/device/endpoint_manager.go b/client/iface/device/endpoint_manager.go new file mode 100644 index 000000000..b53888baa --- /dev/null +++ b/client/iface/device/endpoint_manager.go @@ -0,0 +1,13 @@ +package device + +import ( + "net" + "net/netip" +) + +// EndpointManager manages fake IP to connection mappings for userspace bind implementations. +// Implemented by bind.ICEBind and bind.RelayBindJS. +type EndpointManager interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) +} diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 4649b8b97..cdfcea48d 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -21,4 +21,5 @@ type WGTunDevice interface { FilteredDevice() *device.FilteredDevice Device() *wgdevice.Device GetNet() *netstack.Net + GetICEBind() device.EndpointManager } diff --git a/client/iface/iface.go b/client/iface/iface.go index 158672160..07235a995 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetBind returns the EndpointManager userspace bind mode. +func (w *WGIface) GetBind() device.EndpointManager { + w.mu.Lock() + defer w.mu.Unlock() + + if w.tun == nil { + return nil + } + return w.tun.GetICEBind() +} + // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind func (w *WGIface) IsUserspaceBind() bool { return w.userspaceBind diff --git a/client/internal/lazyconn/activity/lazy_conn.go b/client/internal/lazyconn/activity/lazy_conn.go new file mode 100644 index 000000000..2564a9905 --- /dev/null +++ b/client/internal/lazyconn/activity/lazy_conn.go @@ -0,0 +1,82 @@ +package activity + +import ( + "context" + "io" + "net" + "time" +) + +// lazyConn detects activity when WireGuard attempts to send packets. +// It does not deliver packets, only signals that activity occurred. +type lazyConn struct { + activityCh chan struct{} + ctx context.Context + cancel context.CancelFunc +} + +// newLazyConn creates a new lazyConn for activity detection. +func newLazyConn() *lazyConn { + ctx, cancel := context.WithCancel(context.Background()) + return &lazyConn{ + activityCh: make(chan struct{}, 1), + ctx: ctx, + cancel: cancel, + } +} + +// Read blocks until the connection is closed. +func (c *lazyConn) Read(_ []byte) (n int, err error) { + <-c.ctx.Done() + return 0, io.EOF +} + +// Write signals activity detection when ICEBind routes packets to this endpoint. +func (c *lazyConn) Write(b []byte) (n int, err error) { + if c.ctx.Err() != nil { + return 0, io.EOF + } + + select { + case c.activityCh <- struct{}{}: + default: + } + + return len(b), nil +} + +// ActivityChan returns the channel that signals when activity is detected. +func (c *lazyConn) ActivityChan() <-chan struct{} { + return c.activityCh +} + +// Close closes the connection. +func (c *lazyConn) Close() error { + c.cancel() + return nil +} + +// LocalAddr returns the local address. +func (c *lazyConn) LocalAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// RemoteAddr returns the remote address. +func (c *lazyConn) RemoteAddr() net.Addr { + return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort} +} + +// SetDeadline sets the read and write deadlines. +func (c *lazyConn) SetDeadline(_ time.Time) error { + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (c *lazyConn) SetReadDeadline(_ time.Time) error { + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (c *lazyConn) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/client/internal/lazyconn/activity/listener_bind.go b/client/internal/lazyconn/activity/listener_bind.go new file mode 100644 index 000000000..792d04215 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind.go @@ -0,0 +1,127 @@ +package activity + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type bindProvider interface { + GetBind() device.EndpointManager +} + +const ( + // lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers. + // The actual routing is done via fakeIP in ICEBind, not by this port. + lazyBindPort = 17473 +) + +// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode. +type BindListener struct { + wgIface WgInterface + peerCfg lazyconn.PeerConfig + done sync.WaitGroup + + lazyConn *lazyConn + bind device.EndpointManager + fakeIP netip.Addr +} + +// NewBindListener creates a listener that passes data directly through bind using LazyConn. +// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range. +func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) { + fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs) + if err != nil { + return nil, fmt.Errorf("derive fake IP: %w", err) + } + + d := &BindListener{ + wgIface: wgIface, + peerCfg: cfg, + bind: bind, + fakeIP: fakeIP, + } + + if err := d.setupLazyConn(); err != nil { + return nil, fmt.Errorf("setup lazy connection: %v", err) + } + + d.done.Add(1) + return d, nil +} + +// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP. +// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y). +// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface. +func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) { + if len(allowedIPs) == 0 { + return netip.Addr{}, fmt.Errorf("no allowed IPs for peer") + } + + ourNetwork := wgIface.Address().Network + + var peerIP netip.Addr + for _, allowedIP := range allowedIPs { + ip := allowedIP.Addr() + if !ip.Is4() { + continue + } + if ourNetwork.Contains(ip) { + peerIP = ip + break + } + } + + if !peerIP.IsValid() { + return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs") + } + + octets := peerIP.As4() + fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]}) + return fakeIP, nil +} + +func (d *BindListener) setupLazyConn() error { + d.lazyConn = newLazyConn() + d.bind.SetEndpoint(d.fakeIP, d.lazyConn) + + endpoint := &net.UDPAddr{ + IP: d.fakeIP.AsSlice(), + Port: lazyBindPort, + } + return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil) +} + +// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed. +func (d *BindListener) ReadPackets() { + select { + case <-d.lazyConn.ActivityChan(): + d.peerCfg.Log.Infof("activity detected via LazyConn") + case <-d.lazyConn.ctx.Done(): + d.peerCfg.Log.Infof("exit from activity listener") + } + + d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey) + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { + d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) + } + + _ = d.lazyConn.Close() + d.bind.RemoveEndpoint(d.fakeIP) + d.done.Done() +} + +// Close stops the listener and cleans up resources. +func (d *BindListener) Close() { + d.peerCfg.Log.Infof("closing activity listener (LazyConn)") + + if err := d.lazyConn.Close(); err != nil { + d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err) + } + + d.done.Wait() +} diff --git a/client/internal/lazyconn/activity/listener_bind_test.go b/client/internal/lazyconn/activity/listener_bind_test.go new file mode 100644 index 000000000..f86dd3877 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_bind_test.go @@ -0,0 +1,291 @@ +package activity + +import ( + "net" + "net/netip" + "runtime" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/internal/lazyconn" + peerid "github.com/netbirdio/netbird/client/internal/peer/id" +) + +func isBindListenerPlatform() bool { + return runtime.GOOS == "windows" || runtime.GOOS == "js" +} + +// mockEndpointManager implements device.EndpointManager for testing +type mockEndpointManager struct { + endpoints map[netip.Addr]net.Conn +} + +func newMockEndpointManager() *mockEndpointManager { + return &mockEndpointManager{ + endpoints: make(map[netip.Addr]net.Conn), + } +} + +func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + m.endpoints[fakeIP] = conn +} + +func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) { + delete(m.endpoints, fakeIP) +} + +func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn { + return m.endpoints[fakeIP] +} + +// MockWGIfaceBind mocks WgInterface with bind support +type MockWGIfaceBind struct { + endpointMgr *mockEndpointManager +} + +func (m *MockWGIfaceBind) RemovePeer(string) error { + return nil +} + +func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { + return nil +} + +func (m *MockWGIfaceBind) IsUserspaceBind() bool { + return true +} + +func (m *MockWGIfaceBind) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +func (m *MockWGIfaceBind) GetBind() device.EndpointManager { + return m.endpointMgr +} + +func TestBindListener_Creation(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + expectedFakeIP := netip.MustParseAddr("127.2.0.2") + conn := mockEndpointMgr.GetEndpoint(expectedFakeIP) + require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager") + + _, ok := conn.(*lazyConn) + assert.True(t, ok, "Registered endpoint should be a lazyConn") + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestBindListener_ActivityDetection(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + fakeIP := listener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection") +} + +func TestBindListener_Close(t *testing.T) { + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + fakeIP := listener.fakeIP + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close") +} + +func TestManager_BindMode(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer := &MocPeer{PeerID: "testPeer1"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + err := mgr.MonitorPeerActivity(cfg) + require.NoError(t, err) + + listener, exists := mgr.GetPeerListener(cfg.PeerConnID) + require.True(t, exists, "Peer listener should be found") + + bindListener, ok := listener.(*BindListener) + require.True(t, ok, "Listener should be BindListener, got %T", listener) + + fakeIP := bindListener.fakeIP + conn := mockEndpointMgr.GetEndpoint(fakeIP) + require.NotNil(t, conn, "Endpoint should be registered") + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case peerConnID := <-mgr.OnActivityChan: + assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notification") + } + + assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity") +} + +func TestManager_BindMode_MultiplePeers(t *testing.T) { + if !isBindListenerPlatform() { + t.Skip("BindListener only used on Windows/JS platforms") + } + + mockEndpointMgr := newMockEndpointManager() + mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} + + peer1 := &MocPeer{PeerID: "testPeer1"} + peer2 := &MocPeer{PeerID: "testPeer2"} + mgr := NewManager(mockIface) + defer mgr.Close() + + cfg1 := lazyconn.PeerConfig{ + PublicKey: peer1.PeerID, + PeerConnID: peer1.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + cfg2 := lazyconn.PeerConfig{ + PublicKey: peer2.PeerID, + PeerConnID: peer2.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")}, + Log: log.WithField("peer", "testPeer2"), + } + + err := mgr.MonitorPeerActivity(cfg1) + require.NoError(t, err) + + err = mgr.MonitorPeerActivity(cfg2) + require.NoError(t, err) + + listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID) + require.True(t, exists, "Peer1 listener should be found") + bindListener1 := listener1.(*BindListener) + + listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID) + require.True(t, exists, "Peer2 listener should be found") + bindListener2 := listener2.(*BindListener) + + conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP) + require.NotNil(t, conn1, "Peer1 endpoint should be registered") + _, err = conn1.Write([]byte{0x01}) + require.NoError(t, err) + + conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP) + require.NotNil(t, conn2, "Peer2 endpoint should be registered") + _, err = conn2.Write([]byte{0x02}) + require.NoError(t, err) + + receivedPeers := make(map[peerid.ConnID]bool) + for i := 0; i < 2; i++ { + select { + case peerConnID := <-mgr.OnActivityChan: + receivedPeers[peerConnID] = true + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity notifications") + } + } + + assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received") + assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received") +} diff --git a/client/internal/lazyconn/activity/listener_test.go b/client/internal/lazyconn/activity/listener_test.go deleted file mode 100644 index 98d7838d2..000000000 --- a/client/internal/lazyconn/activity/listener_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package activity - -import ( - "testing" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/lazyconn" -) - -func TestNewListener(t *testing.T) { - peer := &MocPeer{ - PeerID: "examplePublicKey1", - } - - cfg := lazyconn.PeerConfig{ - PublicKey: peer.PeerID, - PeerConnID: peer.ConnID(), - Log: log.WithField("peer", "examplePublicKey1"), - } - - l, err := NewListener(MocWGIface{}, cfg) - if err != nil { - t.Fatalf("failed to create listener: %v", err) - } - - chanClosed := make(chan struct{}) - go func() { - defer close(chanClosed) - l.ReadPackets() - }() - - time.Sleep(1 * time.Second) - l.Close() - - select { - case <-chanClosed: - case <-time.After(time.Second): - } -} diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener_udp.go similarity index 64% rename from client/internal/lazyconn/activity/listener.go rename to client/internal/lazyconn/activity/listener_udp.go index 817ff00c3..e0b09be6c 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener_udp.go @@ -11,26 +11,27 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" ) -// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking -type Listener struct { +// UDPListener uses UDP sockets for activity detection in kernel mode. +type UDPListener struct { wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr done sync.Mutex - isClosed atomic.Bool // use to avoid error log when closing the listener + isClosed atomic.Bool } -func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { - d := &Listener{ +// NewUDPListener creates a listener that detects activity via UDP socket reads. +func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) { + d := &UDPListener{ wgIface: wgIface, peerCfg: cfg, } conn, err := d.newConn() if err != nil { - return nil, fmt.Errorf("failed to creating activity listener: %v", err) + return nil, fmt.Errorf("create UDP connection: %v", err) } d.conn = conn d.endpoint = conn.LocalAddr().(*net.UDPAddr) @@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error if err := d.createEndpoint(); err != nil { return nil, err } + d.done.Lock() - cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String()) + cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String()) return d, nil } -func (d *Listener) ReadPackets() { +// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed. +func (d *UDPListener) ReadPackets() { for { n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) if err != nil { @@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() { } d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String()) - if err := d.removeEndpoint(); err != nil { + if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil { d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) } - _ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection" + // Ignore close error as it may return "use of closed network connection" if already closed. + _ = d.conn.Close() d.done.Unlock() } -func (d *Listener) Close() { +// Close stops the listener and cleans up resources. +func (d *UDPListener) Close() { d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String()) d.isClosed.Store(true) @@ -82,16 +87,12 @@ func (d *Listener) Close() { d.done.Lock() } -func (d *Listener) removeEndpoint() error { - return d.wgIface.RemovePeer(d.peerCfg.PublicKey) -} - -func (d *Listener) createEndpoint() error { +func (d *UDPListener) createEndpoint() error { d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String()) return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil) } -func (d *Listener) newConn() (*net.UDPConn, error) { +func (d *UDPListener) newConn() (*net.UDPConn, error) { addr := &net.UDPAddr{ Port: 0, IP: listenIP, diff --git a/client/internal/lazyconn/activity/listener_udp_test.go b/client/internal/lazyconn/activity/listener_udp_test.go new file mode 100644 index 000000000..d2adb9bf4 --- /dev/null +++ b/client/internal/lazyconn/activity/listener_udp_test.go @@ -0,0 +1,110 @@ +package activity + +import ( + "net" + "net/netip" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +func TestUDPListener_Creation(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + require.NotNil(t, listener.conn) + require.NotNil(t, listener.endpoint) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } +} + +func TestUDPListener_ActivityDetection(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + activityDetected := make(chan struct{}) + go func() { + listener.ReadPackets() + close(activityDetected) + }() + + conn, err := net.Dial("udp", listener.conn.LocalAddr().String()) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte{0x01, 0x02, 0x03}) + require.NoError(t, err) + + select { + case <-activityDetected: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for activity detection") + } +} + +func TestUDPListener_Close(t *testing.T) { + mockIface := &MocWGIface{} + + peer := &MocPeer{PeerID: "testPeer1"} + cfg := lazyconn.PeerConfig{ + PublicKey: peer.PeerID, + PeerConnID: peer.ConnID(), + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, + Log: log.WithField("peer", "testPeer1"), + } + + listener, err := NewUDPListener(mockIface, cfg) + require.NoError(t, err) + + readPacketsDone := make(chan struct{}) + go func() { + listener.ReadPackets() + close(readPacketsDone) + }() + + listener.Close() + + select { + case <-readPacketsDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for ReadPackets to exit after Close") + } + + assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed") +} diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index 915fb9cb8..db283ec9a 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,21 +1,32 @@ package activity import ( + "errors" "net" "net/netip" + "runtime" "sync" "time" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +// listener defines the contract for activity detection listeners. +type listener interface { + ReadPackets() + Close() +} + type WgInterface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + Address() wgaddr.Address } type Manager struct { @@ -23,7 +34,7 @@ type Manager struct { wgIface WgInterface - peers map[peerid.ConnID]*Listener + peers map[peerid.ConnID]listener done chan struct{} mu sync.Mutex @@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, - peers: make(map[peerid.ConnID]*Listener), + peers: make(map[peerid.ConnID]listener), done: make(chan struct{}), } return m @@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error { return nil } - listener, err := NewListener(m.wgIface, peerCfg) + listener, err := m.createListener(peerCfg) if err != nil { return err } - m.peers[peerCfg.PeerConnID] = listener + m.peers[peerCfg.PeerConnID] = listener go m.waitForTraffic(listener, peerCfg.PeerConnID) return nil } +func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) { + if !m.wgIface.IsUserspaceBind() { + return NewUDPListener(m.wgIface, peerCfg) + } + + // BindListener is only used on Windows and JS platforms: + // - JS: Cannot listen to UDP sockets + // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default + // gateway points to, preventing them from reaching the loopback interface. + // BindListener bypasses this by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" { + return NewUDPListener(m.wgIface, peerCfg) + } + + provider, ok := m.wgIface.(bindProvider) + if !ok { + return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") + } + + return NewBindListener(m.wgIface, provider.GetBind(), peerCfg) +} + func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) { m.mu.Lock() defer m.mu.Unlock() @@ -82,8 +115,8 @@ func (m *Manager) Close() { } } -func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) { - listener.ReadPackets() +func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) { + l.ReadPackets() m.mu.Lock() if _, ok := m.peers[peerConnID]; !ok { diff --git a/client/internal/lazyconn/activity/manager_test.go b/client/internal/lazyconn/activity/manager_test.go index ae6c31da4..0768d9219 100644 --- a/client/internal/lazyconn/activity/manager_test.go +++ b/client/internal/lazyconn/activity/manager_test.go @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) @@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error { func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error { return nil - } -// Add this method to the Manager struct -func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) { +func (m MocWGIface) IsUserspaceBind() bool { + return false +} + +func (m MocWGIface) Address() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/16"), + } +} + +// GetPeerListener is a test helper to access listeners +func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) { m.mu.Lock() defer m.mu.Unlock() - listener, exists := m.peers[peerConnID] - return listener, exists + l, exists := m.peers[peerConnID] + return l, exists } func TestManager_MonitorPeerActivity(t *testing.T) { @@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) { t.Fatalf("peer listener not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + // Get the UDP listener's address for triggering + udpListener, ok := listener.(*UDPListener) + if !ok { + t.Fatalf("expected UDPListener") + } + if err := trigger(udpListener.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) { t.Fatalf("failed to monitor peer activity: %v", err) } - addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String() + listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID) + udpListener, _ := listener.(*UDPListener) + addr := udpListener.conn.LocalAddr().String() mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID) @@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer1 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener1, _ := listener.(*UDPListener) + if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } @@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) { t.Fatalf("peer listener for peer2 not found") } - if err := trigger(listener.conn.LocalAddr().String()); err != nil { + udpListener2, _ := listener.(*UDPListener) + if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil { t.Fatalf("failed to trigger activity: %v", err) } diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 0351904f7..0626c1815 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -7,6 +7,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/monotime" ) @@ -14,5 +15,6 @@ type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error IsUserspaceBind() bool + Address() wgaddr.Address LastActivities() map[string]monotime.Time } From 3abae0bd170c95b79378ab1d59ded7fb83ca985e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:16:51 +0200 Subject: [PATCH 18/29] [client] Set default wg port for new profiles (#4651) --- client/internal/profilemanager/config.go | 1 + client/internal/profilemanager/config_test.go | 92 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index 4e6b422f6..f03822089 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) { config := &Config{ // defaults to false only for new (post 0.26) configurations ServerSSHAllowed: util.False(), + WgPort: iface.DefaultWgPort, } if _, err := config.apply(input); err != nil { diff --git a/client/internal/profilemanager/config_test.go b/client/internal/profilemanager/config_test.go index 45e37bf0e..90bde7707 100644 --- a/client/internal/profilemanager/config_test.go +++ b/client/internal/profilemanager/config_test.go @@ -5,11 +5,14 @@ import ( "errors" "os" "path/filepath" + "runtime" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/util" ) @@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) { } } +func TestNewProfileDefaults(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + }) + require.NoError(t, err, "should create new config") + + assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default") + assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default") + assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated") + assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated") + assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default") + assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820") + assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default") + assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default") + assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set") + assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set") + assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults") + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS") + assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS") + } +} + +func TestWireguardPortZeroExplicit(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + // Create a new profile with explicit port 0 (random port) + explicitZero := 0 + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: &explicitZero, + }) + require.NoError(t, err, "should create config with explicit port 0") + + assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user") + + // Verify it persists + readConfig, err := GetConfig(configPath) + require.NoError(t, err) + assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file") +} + +func TestWireguardPortDefaultVsExplicit(t *testing.T) { + tests := []struct { + name string + wireguardPort *int + expectedPort int + description string + }{ + { + name: "no port specified uses default", + wireguardPort: nil, + expectedPort: iface.DefaultWgPort, + description: "When user doesn't specify port, default to 51820", + }, + { + name: "explicit zero for random port", + wireguardPort: func() *int { v := 0; return &v }(), + expectedPort: 0, + description: "When user explicitly sets 0, use 0 for random port", + }, + { + name: "explicit custom port", + wireguardPort: func() *int { v := 52000; return &v }(), + expectedPort: 52000, + description: "When user sets custom port, use that port", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.json") + + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: configPath, + WireguardPort: tt.wireguardPort, + }) + require.NoError(t, err, tt.description) + assert.Equal(t, tt.expectedPort, config.WgPort, tt.description) + }) + } +} + func TestUpdateOldManagementURL(t *testing.T) { tests := []struct { name string From af95aabb0375f51872a4ff69dd12f412d716955e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 16 Oct 2025 17:15:39 +0200 Subject: [PATCH 19/29] Handle the case when the service has already been down and the status recorder is not available (#4652) --- client/internal/debug/wgshow.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/internal/debug/wgshow.go b/client/internal/debug/wgshow.go index e4b4c2368..8233ca510 100644 --- a/client/internal/debug/wgshow.go +++ b/client/internal/debug/wgshow.go @@ -14,6 +14,9 @@ type WGIface interface { } func (g *BundleGenerator) addWgShow() error { + if g.statusRecorder == nil { + return fmt.Errorf("no status recorder available for wg show") + } result, err := g.statusRecorder.PeersStatus() if err != nil { return err From 3cdb10cde765621086d15e97da150fee884356f6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:09:39 +0200 Subject: [PATCH 20/29] [client] Remove rule squashing (#4653) --- client/firewall/iptables/acl_linux.go | 1 - client/internal/acl/manager.go | 157 +-------- client/internal/acl/manager_test.go | 486 -------------------------- 3 files changed, 3 insertions(+), 641 deletions(-) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index ed8a7403b..d78372c9e 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi return "" } - // Include action in the ipset name to prevent squashing rules with different actions actionSuffix := "" if action == firewall.ActionDrop { actionSuffix = "-drop" diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index 5ca950297..965decc73 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -29,11 +29,6 @@ type Manager interface { ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) } -type protoMatch struct { - ips map[string]int - policyID []byte -} - // DefaultManager uses firewall manager to handle type DefaultManager struct { firewall firewall.Manager @@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout } func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { - rules, squashedProtocols := d.squashAcceptRules(networkMap) + rules := networkMap.FirewallRules enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig.SshEnabled - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - enableSSH = enableSSH && !ok - } - if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { - enableSSH = enableSSH && !ok - } - // if TCP protocol rules not squashed and SSH enabled - // we add default firewall rule which accepts connection to any peer - // in the network by SSH (TCP 22 port). + // If SSH enabled, add default firewall rule which accepts connection to any peer + // in the network by SSH (TCP port defined by ssh.DefaultSSHPort). if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", @@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID( return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } -// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type -// to all peers in the network map to one rule which just accepts that type of the traffic. -// -// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, -// but other has port definitions or has drop policy. -func (d *DefaultManager) squashAcceptRules( - networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { - totalIPs := 0 - for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { - for range p.AllowedIps { - totalIPs++ - } - } - - in := map[mgmProto.RuleProtocol]*protoMatch{} - out := map[mgmProto.RuleProtocol]*protoMatch{} - - // trace which type of protocols was squashed - squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} - - // this function we use to do calculation, can we squash the rules by protocol or not. - // We summ amount of Peers IP for given protocol we found in original rules list. - // But we zeroed the IP's for protocol if: - // 1. Any of the rule has DROP action type. - // 2. Any of rule contains Port. - // - // We zeroed this to notify squash function that this protocol can't be squashed. - addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { - hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP || - r.Port != "" || !portInfoEmpty(r.PortInfo) - - if hasPortRestrictions { - // Don't squash rules with port restrictions - protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} - return - } - - if _, ok := protocols[r.Protocol]; !ok { - protocols[r.Protocol] = &protoMatch{ - ips: map[string]int{}, - // store the first encountered PolicyID for this protocol - policyID: r.PolicyID, - } - } - - // special case, when we receive this all network IP address - // it means that rules for that protocol was already optimized on the - // management side - if r.PeerIP == "0.0.0.0" { - squashedRules = append(squashedRules, r) - squashedProtocols[r.Protocol] = struct{}{} - return - } - - ipset := protocols[r.Protocol].ips - - if _, ok := ipset[r.PeerIP]; ok { - return - } - ipset[r.PeerIP] = i - } - - for i, r := range networkMap.FirewallRules { - // calculate squash for different directions - if r.Direction == mgmProto.RuleDirection_IN { - addRuleToCalculationMap(i, r, in) - } else { - addRuleToCalculationMap(i, r, out) - } - } - - // order of squashing by protocol is important - // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.RuleProtocol{ - mgmProto.RuleProtocol_ALL, - mgmProto.RuleProtocol_ICMP, - mgmProto.RuleProtocol_TCP, - mgmProto.RuleProtocol_UDP, - } - - squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) { - for _, protocol := range protocolOrders { - match, ok := matches[protocol] - if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 { - // don't squash if : - // 1. Rules not cover all peers in the network - // 2. Rules cover only one peer in the network. - continue - } - - // add special rule 0.0.0.0 which allows all IP's in our firewall implementations - squashedRules = append(squashedRules, &mgmProto.FirewallRule{ - PeerIP: "0.0.0.0", - Direction: direction, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: protocol, - PolicyID: match.policyID, - }) - squashedProtocols[protocol] = struct{}{} - - if protocol == mgmProto.RuleProtocol_ALL { - // if we have ALL traffic type squashed rule - // it allows all other type of traffic, so we can stop processing - break - } - } - } - - squash(in, mgmProto.RuleDirection_IN) - squash(out, mgmProto.RuleDirection_OUT) - - // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { - return squashedRules, squashedProtocols - } - - if len(squashedRules) == 0 { - return networkMap.FirewallRules, squashedProtocols - } - - var rules []*mgmProto.FirewallRule - // filter out rules which was squashed from final list - // if we also have other not squashed rules. - for i, r := range networkMap.FirewallRules { - if _, ok := squashedProtocols[r.Protocol]; ok { - if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i { - continue - } - } - rules = append(rules, r) - } - - return append(rules, squashedRules...), squashedProtocols -} - // getRuleGroupingSelector takes all rule properties except IP address to build selector func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 664476ef4..daf4979ce 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) { }) } -func TestDefaultManagerSquashRules(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, 2, len(rules)) - - r := rules[0] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) - - r = rules[1] - assert.Equal(t, "0.0.0.0", r.PeerIP) - assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction) - assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action) -} - -func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_ALL, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_OUT, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - assert.Equal(t, len(networkMap.FirewallRules), len(rules)) -} - -func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) { - tests := []struct { - name string - rules []*mgmProto.FirewallRule - expectedCount int - description string - }{ - { - name: "should not squash rules with port ranges", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Range_{ - Range: &mgmProto.PortInfo_Range{ - Start: 8080, - End: 8090, - }, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with port ranges should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with specific ports", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - }, - expectedCount: 4, - description: "Rules with specific ports should not be squashed even if they cover all peers", - }, - { - name: "should not squash rules with legacy port field", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - }, - expectedCount: 4, - description: "Rules with legacy port field should not be squashed", - }, - { - name: "should not squash rules with DROP action", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_DROP, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "Rules with DROP action should not be squashed", - }, - { - name: "should squash rules without port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 1, - description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule", - }, - { - name: "mixed rules should not squash protocol with port restrictions", - rules: []*mgmProto.FirewallRule{ - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - PortInfo: &mgmProto.PortInfo{ - PortSelection: &mgmProto.PortInfo_Port{ - Port: 80, - }, - }, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - }, - }, - expectedCount: 4, - description: "TCP should not be squashed because one rule has port restrictions", - }, - { - name: "should squash UDP but not TCP when TCP has port restrictions", - rules: []*mgmProto.FirewallRule{ - // TCP rules with port restrictions - should NOT be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_TCP, - Port: "443", - }, - // UDP rules without port restrictions - SHOULD be squashed - { - PeerIP: "10.93.0.1", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.2", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.3", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - { - PeerIP: "10.93.0.4", - Direction: mgmProto.RuleDirection_IN, - Action: mgmProto.RuleAction_ACCEPT, - Protocol: mgmProto.RuleProtocol_UDP, - }, - }, - expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0) - description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - networkMap := &mgmProto.NetworkMap{ - RemotePeers: []*mgmProto.RemotePeerConfig{ - {AllowedIps: []string{"10.93.0.1"}}, - {AllowedIps: []string{"10.93.0.2"}}, - {AllowedIps: []string{"10.93.0.3"}}, - {AllowedIps: []string{"10.93.0.4"}}, - }, - FirewallRules: tt.rules, - } - - manager := &DefaultManager{} - rules, _ := manager.squashAcceptRules(networkMap) - - assert.Equal(t, tt.expectedCount, len(rules), tt.description) - - // For squashed rules, verify we get the expected 0.0.0.0 rule - if tt.expectedCount == 1 { - assert.Equal(t, "0.0.0.0", rules[0].PeerIP) - assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction) - assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action) - } - }) - } -} - func TestPortInfoEmpty(t *testing.T) { tests := []struct { name string From 429d7d65855bec44481e572c83e1429ad0c50ebd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:10:16 +0200 Subject: [PATCH 21/29] [client] Support BROWSER env for login (#4654) --- client/cmd/login.go | 11 ++++++++++- client/ui/client_ui.go | 7 +++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 3ac211805..40b55f858 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" "runtime" "strings" @@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := open.Run(verificationURIComplete); err != nil { + if err := openBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } +// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func openBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} + // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 7c2000a9d..0043f228e 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -31,7 +31,6 @@ import ( "fyne.io/systray" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { } func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := open.Run(loginResp.VerificationURIComplete) + err := openURL(loginResp.VerificationURIComplete) if err != nil { log.Errorf("opening the verification uri in the browser failed: %v", err) return err @@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } func openURL(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + var err error switch runtime.GOOS { case "windows": From f5301230bfef5d8b3abaf60634646793fa7b63ac Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 17 Oct 2025 13:31:15 +0200 Subject: [PATCH 22/29] [client] Fix status showing P2P without connection (#4661) --- client/status/status.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..5e4fcd8dc 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -205,15 +205,18 @@ func mapPeers( localICEEndpoint := "" remoteICEEndpoint := "" relayServerAddress := "" - connType := "P2P" + connType := "-" lastHandshake := time.Time{} transferReceived := int64(0) transferSent := int64(0) isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() - if pbPeerState.Relayed { - connType = "Relayed" + if isPeerConnected { + connType = "P2P" + if pbPeerState.Relayed { + connType = "Relayed" + } } if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { From 0f9bfeff7cf36c81086c8dc91136c9194dbaf819 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 17 Oct 2025 14:47:11 -0300 Subject: [PATCH 23/29] [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 --- client/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ From cd9a867ad02034570daa415fce6fe1dd47c1ccb8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 17 Oct 2025 19:48:26 +0200 Subject: [PATCH 24/29] [client] Delete TURNConfig section from script (#4639) --- infrastructure_files/getting-started-with-zitadel.sh | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", From 2fe2af38d29084f9bf226a83c8aecf7440c633bd Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:14:39 +0200 Subject: [PATCH 25/29] [client] Clean up match domain reg entries between config changes (#4676) --- client/internal/dns/host_windows.go | 12 ++- client/internal/dns/host_windows_test.go | 102 ++++++++++++++++++ .../internal/winregistry/volatile_windows.go | 59 ++++++++++ 3 files changed, 168 insertions(+), 5 deletions(-) create mode 100644 client/internal/dns/host_windows_test.go create mode 100644 client/internal/winregistry/volatile_windows.go diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..74111d335 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -197,6 +198,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,9 +209,6 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } @@ -273,9 +275,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i := 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. +// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} From 96f71ff1e15b50eff87a7052432537dad2718b20 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 21 Oct 2025 19:23:11 +0200 Subject: [PATCH 26/29] [misc] Update tag name extraction in install.sh (#4677) --- release_files/install.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" From d80d47a469ab7d1d740d28ec836d65dc7776beec Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 22 Oct 2025 12:46:22 +0300 Subject: [PATCH 27/29] [management] Add peer disapproval reason (#4468) --- go.mod | 2 +- go.sum | 4 +- management/server/account/manager.go | 2 +- .../http/handlers/peers/peers_handler.go | 38 ++++++++++++------- management/server/integrated_validator.go | 24 +++++++++--- .../integrated_validator/interface.go | 1 + management/server/mock_server/account_mock.go | 6 +-- shared/management/http/api/openapi.yml | 3 ++ shared/management/http/api/types.gen.go | 6 +++ 9 files changed, 61 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index 31b45e881..79dd92e6b 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f + github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 6b0b298a7..f0065e081 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/account/manager.go b/management/server/account/manager.go index a1ed9498b..fe9fb25c6 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -109,7 +109,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 251c04273..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index be05c2527..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -15,6 +15,7 @@ type IntegratedValidator interface { PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d160e7269..e87043f26 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` From 6654e2dbf7c3666a432d7ff457eb3cc5dfdc59aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:07:52 +0200 Subject: [PATCH 28/29] [client] Fix active profile name in debug bundle (#4689) --- client/cmd/debug.go | 8 +++++++- client/internal/debug/debug.go | 4 +++- client/ui/debug.go | 19 ++++++++++++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..d53c5f06b 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index ec920c5f3..442f54e71 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. @@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( "github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } From 709e24eb6f7b4beff4da86617bbcf518434b8764 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 24 Oct 2025 15:40:20 +0300 Subject: [PATCH 29/29] [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644) This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS. --- infrastructure_files/configure.sh | 5 ++++- infrastructure_files/docker-compose.yml.tmpl | 8 +++++++ signal/cmd/run.go | 23 ++++++++++---------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {