From 69d87343d2bbde0d30deaed7b5f2970827e93c5e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 8 Sep 2025 14:51:34 +0200 Subject: [PATCH 01/30] [client] Debug information for connection (#4439) Improve logging Print the exact time when the first WireGuard handshake occurs Print the steps for gathering system information --- client/iface/configurer/usp.go | 7 +++++++ client/internal/engine.go | 3 +-- client/internal/peer/wg_watcher.go | 13 ++++++++++--- client/system/info.go | 4 ++++ client/system/info_windows.go | 1 - 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 171458e38..945f1a162 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) { if err != nil { return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) } + + // If sec is 0 (Unix epoch), return zero time instead + // This indicates no handshake has occurred + if sec == 0 { + return time.Time{}, nil + } + return time.Unix(sec, 0), nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index ca01bfd14..61ff41600 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -949,7 +949,6 @@ func (e *Engine) receiveManagementEvents() { e.config.LazyConnectionEnabled, ) - // err = e.mgmClient.Sync(info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync) if err != nil { // happens if management is unavailable for a long time. @@ -960,7 +959,7 @@ func (e *Engine) receiveManagementEvents() { } log.Debugf("stopped receiving updates from Management Service") }() - log.Debugf("connecting to Management Service updates stream") + log.Infof("connecting to Management Service updates stream") } func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 218872c15..0ed200fda 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,9 +30,10 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex + enabledTime time.Time } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() + w.enabledTime = time.Now() if w.ctx != nil && w.ctx.Err() == nil { w.log.Errorf("WireGuard watcher already enabled") @@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex onDisconnectedFn() return } + if lastHandshake.IsZero() { + elapsed := handshake.Sub(w.enabledTime).Seconds() + w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) + } + lastHandshake = *handshake resetTime := time.Until(handshake.Add(checkPeriod)) diff --git a/client/system/info.go b/client/system/info.go index ceb1682f3..a180be4c0 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" + log "github.com/sirupsen/logrus" "google.golang.org/grpc/metadata" "github.com/netbirdio/netbird/shared/management/proto" @@ -172,6 +173,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { // GetInfoWithChecks retrieves and parses the system information with applied checks. func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { + log.Debugf("gathering system information with checks: %d", len(checks)) processCheckPaths := make([]string, 0) for _, check := range checks { processCheckPaths = append(processCheckPaths, check.GetFiles()...) @@ -181,9 +183,11 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro if err != nil { return nil, err } + log.Debugf("gathering process check information completed") info := GetInfo(ctx) info.Files = files + log.Debugf("all system information gathered successfully") return info, nil } diff --git a/client/system/info_windows.go b/client/system/info_windows.go index e67356f57..d7f8f30aa 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -48,6 +48,5 @@ func GetInfo(ctx context.Context) *Info { gio.Hostname = extractDeviceName(ctx, systemHostname) gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) - return gio } From dba7ef667d6921e4d7b0381df28a3b5ae8d3db16 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 8 Sep 2025 10:03:56 -0300 Subject: [PATCH 02/30] [misc] Remove aur support and start service on ostree (#4461) * Remove aur support and start service on ostree The aur installation was adding many packages and installing more than just the client. For now is best to remove it and rely on binary install Some users complained about ostree installation not starting the client, we add two explicit commands to it * use ${SUDO} * fix if closure --- release_files/install.sh | 43 +++------------------------------------- 1 file changed, 3 insertions(+), 40 deletions(-) diff --git a/release_files/install.sh b/release_files/install.sh index 856d332cb..5d5349ec4 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -130,36 +130,6 @@ repo_gpgcheck=1 EOF } -install_aur_package() { - INSTALL_PKGS="git base-devel go" - REMOVE_PKGS="" - - # Check if dependencies are installed - for PKG in $INSTALL_PKGS; do - if ! pacman -Q "$PKG" > /dev/null 2>&1; then - # Install missing package(s) - ${SUDO} pacman -S "$PKG" --noconfirm - - # Add installed package for clean up later - REMOVE_PKGS="$REMOVE_PKGS $PKG" - fi - done - - # Build package from AUR - cd /tmp && git clone https://aur.archlinux.org/netbird.git - cd netbird && makepkg -sri --noconfirm - - if ! $SKIP_UI_APP; then - cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git - cd netbird-ui && makepkg -sri --noconfirm - fi - - if [ -n "$REMOVE_PKGS" ]; then - # Clean up the installed packages - ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm - fi -} - prepare_tun_module() { # Create the necessary file structure for /dev/net/tun if [ ! -c /dev/net/tun ]; then @@ -276,12 +246,9 @@ install_netbird() { if ! $SKIP_UI_APP; then ${SUDO} rpm-ostree -y install netbird-ui fi - ;; - pacman) - ${SUDO} pacman -Syy - install_aur_package - # in-line with the docs at https://wiki.archlinux.org/title/Netbird - ${SUDO} systemctl enable --now netbird@main.service + # ensure the service is started after install + ${SUDO} netbird service install || true + ${SUDO} netbird service start || true ;; pkg) # Check if the package is already installed @@ -458,11 +425,7 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" - elif [ -x "$(command -v pacman)" ]; then - PACKAGE_MANAGER="pacman" - echo "The installation will be performed using pacman package manager" fi - else echo "Unable to determine OS type from /etc/os-release" exit 1 From 7aef0f67df4b5974fc08b2808b8de3998e8a9de9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 8 Sep 2025 18:42:42 +0200 Subject: [PATCH 03/30] [client] Implement environment variable handling for Android (#4440) Some features can only be manipulated via environment variables. With this PR, environment variables can be managed from Android. --- client/android/client.go | 18 ++++++++++++++++-- client/android/env_list.go | 32 ++++++++++++++++++++++++++++++++ client/internal/peer/conn.go | 3 +-- client/internal/peer/env.go | 14 ++++++++++++++ 4 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 client/android/env_list.go create mode 100644 client/internal/peer/env.go diff --git a/client/android/client.go b/client/android/client.go index c05246569..4b4fcc9be 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -4,6 +4,7 @@ package android import ( "context" + "os" "slices" "sync" @@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi } // Run start the internal client. It is a blocker function -func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // In this case make no sense handle registration steps. -func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error { +func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error { + exportEnvList(envList) cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ ConfigPath: c.cfgFile, }) @@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) { func (c *Client) RemoveConnectionListener() { c.recorder.RemoveConnectionListener() } + +func exportEnvList(list *EnvList) { + if list == nil { + return + } + for k, v := range list.AllItems() { + if err := os.Setenv(k, v); err != nil { + log.Errorf("could not set env variable %s: %v", k, err) + } + } +} diff --git a/client/android/env_list.go b/client/android/env_list.go new file mode 100644 index 000000000..04122300a --- /dev/null +++ b/client/android/env_list.go @@ -0,0 +1,32 @@ +package android + +import "github.com/netbirdio/netbird/client/internal/peer" + +var ( + // EnvKeyNBForceRelay Exported for Android java client + EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay +) + +// EnvList wraps a Go map for export to Java +type EnvList struct { + data map[string]string +} + +// NewEnvList creates a new EnvList +func NewEnvList() *EnvList { + return &EnvList{data: make(map[string]string)} +} + +// Put adds a key-value pair +func (el *EnvList) Put(key, value string) { + el.data[key] = value +} + +// Get retrieves a value by key +func (el *EnvList) Get(key string) string { + return el.data[key] +} + +func (el *EnvList) AllItems() map[string]string { + return el.data +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 224a8144c..86e4596d4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -6,7 +6,6 @@ import ( "math/rand" "net" "net/netip" - "os" "runtime" "sync" "time" @@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) - if os.Getenv("NB_FORCE_RELAY") != "true" { + if !isForceRelayed() { conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) } diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go new file mode 100644 index 000000000..32a458d00 --- /dev/null +++ b/client/internal/peer/env.go @@ -0,0 +1,14 @@ +package peer + +import ( + "os" + "strings" +) + +const ( + EnvKeyNBForceRelay = "NB_FORCE_RELAY" +) + +func isForceRelayed() bool { + return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") +} From 9e81e782e50b734f68b107ccb0e2ca37c74d6711 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 11 Sep 2025 10:08:54 +0200 Subject: [PATCH 04/30] [client] Fix/v4 stun routing (#4430) Deduplicate STUN package sending. Originally, because every peer shared the same UDP address, the library could not distinguish which STUN message was associated with which candidate. As a result, the Pion library responded from all candidates for every STUN message. --- client/iface/bind/ice_bind.go | 13 +- client/iface/bind/udp_mux_ios.go | 7 - client/iface/device.go | 4 +- client/iface/device/device_android.go | 5 +- client/iface/device/device_darwin.go | 5 +- client/iface/device/device_ios.go | 5 +- client/iface/device/device_kernel_unix.go | 12 +- client/iface/device/device_netstack.go | 5 +- client/iface/device/device_usp_unix.go | 5 +- client/iface/device/device_windows.go | 5 +- client/iface/device_android.go | 4 +- client/iface/iface.go | 6 +- .../udp_muxed_conn.go => udpmux/conn.go} | 17 +- client/iface/udpmux/doc.go | 64 ++++++ .../iface/{bind/udp_mux.go => udpmux/mux.go} | 193 +++++++++++------- .../mux_generic.go} | 4 +- client/iface/udpmux/mux_ios.go | 7 + .../universal.go} | 14 +- client/internal/engine.go | 8 +- client/internal/engine_test.go | 9 +- client/internal/iface_common.go | 4 +- client/internal/peer/worker_ice.go | 78 +++---- go.mod | 2 +- go.sum | 4 +- 24 files changed, 301 insertions(+), 179 deletions(-) delete mode 100644 client/iface/bind/udp_mux_ios.go rename client/iface/{bind/udp_muxed_conn.go => udpmux/conn.go} (95%) create mode 100644 client/iface/udpmux/doc.go rename client/iface/{bind/udp_mux.go => udpmux/mux.go} (65%) rename client/iface/{bind/udp_mux_generic.go => udpmux/mux_generic.go} (85%) create mode 100644 client/iface/udpmux/mux_ios.go rename client/iface/{bind/udp_mux_universal.go => udpmux/universal.go} (97%) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 359d2129b..b74f90d6c 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -15,6 +15,7 @@ import ( "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -44,7 +45,7 @@ type ICEBind struct { RecvChan chan RecvMessage transportNet transport.Net - filterFn FilterFn + filterFn udpmux.FilterFn endpoints map[netip.Addr]net.Conn endpointsMu sync.Mutex // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a @@ -54,13 +55,13 @@ type ICEBind struct { closed bool muUDPMux sync.Mutex - udpMux *UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault address wgaddr.Address mtu uint16 activityRecorder *ActivityRecorder } -func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, @@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder { } // GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { +func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() if s.udpMux == nil { @@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = NewUniversalUDPMuxDefault( - UniversalUDPMuxParams{ + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ UDPConn: nbnet.WrapPacketConn(conn), Net: s.transportNet, FilterFn: s.filterFn, diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go deleted file mode 100644 index db0249d11..000000000 --- a/client/iface/bind/udp_mux_ios.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build ios - -package bind - -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { - // iOS doesn't support nbnet hooks, so this is a no-op -} diff --git a/client/iface/device.go b/client/iface/device.go index ca6dda2c2..921f0ea98 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -7,14 +7,14 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index fe3b9f82e..a731684cc 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type WGTunDevice struct { name string device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string } return t.configurer, nil } -func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index cce9d42df..390efe088 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -26,7 +27,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 168985b5e..96e4c8bcf 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -28,7 +29,7 @@ type TunDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 00a72bcc6..2ef6f6b22 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -12,8 +12,8 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/sharedsock" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,9 +31,9 @@ type TunKernelDevice struct { link *wgLink udpMuxConn net.PacketConn - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault - filterFn bind.FilterFn + filterFn udpmux.FilterFn } func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { @@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) { return configurer, nil } -func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -106,14 +106,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { udpConn = nbnet.WrapPacketConn(rawSock) } - bindParams := bind.UniversalUDPMuxParams{ + bindParams := udpmux.UniversalUDPMuxParams{ UDPConn: udpConn, Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, MTU: t.mtu, } - mux := bind.NewUniversalUDPMuxDefault(bindParams) + mux := udpmux.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) t.udpMuxConn = rawSock t.udpMux = mux diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f41331ff7..2fcc74809 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -10,6 +10,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -26,7 +27,7 @@ type TunNetstackDevice struct { device *device.Device filteredDevice *FilteredDevice nsTun *nbnetstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer net *netstack.Net @@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 8d30112ae..4cdd70a32 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -25,7 +26,7 @@ type USPDevice struct { device *device.Device filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index de258868f..f1023bc0a 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -29,7 +30,7 @@ type TunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun filteredDevice *FilteredDevice - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault configurer WGConfigurer } @@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return t.configurer, nil } -func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 39b5c28ae..4649b8b97 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -5,14 +5,14 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(address wgaddr.Address) error WgAddress() wgaddr.Address MTU() uint16 diff --git a/client/iface/iface.go b/client/iface/iface.go index 9a42223a1..609572561 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -16,9 +16,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -61,7 +61,7 @@ type WGIFaceOpts struct { MTU uint16 MobileArgs *device.MobileIFaceArguments TransportNet transport.Net - FilterFn bind.FilterFn + FilterFn udpmux.FilterFn DisableDNS bool } @@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface { // Up configures a Wireguard interface // The interface must exist before calling this method (e.g. call interface.Create() before) -func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { w.mu.Lock() defer w.mu.Unlock() diff --git a/client/iface/bind/udp_muxed_conn.go b/client/iface/udpmux/conn.go similarity index 95% rename from client/iface/bind/udp_muxed_conn.go rename to client/iface/udpmux/conn.go index 7cacf1c31..3aa40caeb 100644 --- a/client/iface/bind/udp_muxed_conn.go +++ b/client/iface/udpmux/conn.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements @@ -16,11 +16,12 @@ import ( ) type udpMuxedConnParams struct { - Mux *UDPMuxDefault - AddrPool *sync.Pool - Key string - LocalAddr net.Addr - Logger logging.LeveledLogger + Mux *SingleSocketUDPMux + AddrPool *sync.Pool + Key string + LocalAddr net.Addr + Logger logging.LeveledLogger + CandidateID string } // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag @@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error { return err } +func (c *udpMuxedConn) GetCandidateID() string { + return c.params.CandidateID +} + func (c *udpMuxedConn) isClosed() bool { select { case <-c.closedChan: diff --git a/client/iface/udpmux/doc.go b/client/iface/udpmux/doc.go new file mode 100644 index 000000000..27e5e43bc --- /dev/null +++ b/client/iface/udpmux/doc.go @@ -0,0 +1,64 @@ +// Package udpmux provides a custom implementation of a UDP multiplexer +// that allows multiple logical ICE connections to share a single underlying +// UDP socket. This is based on Pion's ICE library, with modifications for +// NetBird's requirements. +// +// # Background +// +// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity +// Establishment) is responsible for discovering candidate network paths +// and maintaining connectivity between peers. Each ICE connection +// normally requires a dedicated UDP socket. However, using one socket +// per candidate can be inefficient and difficult to manage. +// +// This package introduces SingleSocketUDPMux, which allows multiple ICE +// candidate connections (muxed connections) to share a single UDP socket. +// It handles demultiplexing of packets based on ICE ufrag values, STUN +// attributes, and candidate IDs. +// +// # Usage +// +// The typical flow is: +// +// 1. Create a UDP socket (net.PacketConn). +// 2. Construct Params with the socket and optional logger/net stack. +// 3. Call NewSingleSocketUDPMux(params). +// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID) +// to obtain a logical PacketConn. +// 5. Use the returned PacketConn just like a normal UDP connection. +// +// # STUN Message Routing Logic +// +// When a STUN packet arrives, the mux decides which connection should +// receive it using this routing logic: +// +// Primary Routing: Candidate Pair ID +// - Extract the candidate pair ID from the STUN message using +// ice.CandidatePairIDFromSTUN(msg) +// - The target candidate is the locally generated candidate that +// corresponds to the connection that should handle this STUN message +// - If found, use the target candidate ID to lookup the specific +// connection in candidateConnMap +// - Route the message directly to that connection +// +// Fallback Routing: Broadcasting +// When candidate pair ID is not available or lookup fails: +// - Collect connections from addressMap based on source address +// - Find connection using username attribute (ufrag) from STUN message +// - Remove duplicate connections from the list +// - Send the STUN message to all collected connections +// +// # Peer Reflexive Candidate Discovery +// +// When a remote peer sends a STUN message from an unknown source address +// (from a candidate that has not been exchanged via signal), the ICE +// library will: +// - Generate a new peer reflexive candidate for this source address +// - Extract or assign a candidate ID based on the STUN message attributes +// - Create a mapping between the new peer reflexive candidate ID and +// the appropriate local connection +// +// This discovery mechanism ensures that STUN messages from newly discovered +// peer reflexive candidates can be properly routed to the correct local +// connection without requiring fallback broadcasting. +package udpmux diff --git a/client/iface/bind/udp_mux.go b/client/iface/udpmux/mux.go similarity index 65% rename from client/iface/bind/udp_mux.go rename to client/iface/udpmux/mux.go index db7494405..319724926 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/udpmux/mux.go @@ -1,4 +1,4 @@ -package bind +package udpmux import ( "fmt" @@ -22,9 +22,9 @@ import ( const receiveMTU = 8192 -// UDPMuxDefault is an implementation of the interface -type UDPMuxDefault struct { - params UDPMuxParams +// SingleSocketUDPMux is an implementation of the interface +type SingleSocketUDPMux struct { + params Params closedChan chan struct{} closeOnce sync.Once @@ -32,6 +32,9 @@ type UDPMuxDefault struct { // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType connsIPv4, connsIPv6 map[string]*udpMuxedConn + // candidateConnMap maps local candidate IDs to their corresponding connection. + candidateConnMap map[string]*udpMuxedConn + addressMapMu sync.RWMutex addressMap map[string][]*udpMuxedConn @@ -46,8 +49,8 @@ type UDPMuxDefault struct { const maxAddrSize = 512 -// UDPMuxParams are parameters for UDPMux. -type UDPMuxParams struct { +// Params are parameters for UDPMux. +type Params struct { Logger logging.LeveledLogger UDPConn net.PacketConn @@ -147,18 +150,19 @@ func isZeros(ip net.IP) bool { return true } -// NewUDPMuxDefault creates an implementation of UDPMux -func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { +// NewSingleSocketUDPMux creates an implementation of UDPMux +func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux { if params.Logger == nil { params.Logger = getLogger() } - mux := &UDPMuxDefault{ - addressMap: map[string][]*udpMuxedConn{}, - params: params, - connsIPv4: make(map[string]*udpMuxedConn), - connsIPv6: make(map[string]*udpMuxedConn), - closedChan: make(chan struct{}, 1), + mux := &SingleSocketUDPMux{ + addressMap: map[string][]*udpMuxedConn{}, + params: params, + connsIPv4: make(map[string]*udpMuxedConn), + connsIPv6: make(map[string]*udpMuxedConn), + candidateConnMap: make(map[string]*udpMuxedConn), + closedChan: make(chan struct{}, 1), pool: &sync.Pool{ New: func() interface{} { // big enough buffer to fit both packet and address @@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { return mux } -func (m *UDPMuxDefault) updateLocalAddresses() { +func (m *SingleSocketUDPMux) updateLocalAddresses() { var localAddrsForUnspecified []net.Addr if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) } else if ok && addr.IP.IsUnspecified() { // For unspecified addresses, the correct behavior is to return errListenUnspecified, but // it will break the applications that are already using unspecified UDP connection - // with UDPMuxDefault, so print a warn log and create a local address list for mux. - m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") + // with SingleSocketUDPMux, so print a warn log and create a local address list for mux. + m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { @@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() { m.mu.Unlock() } -// LocalAddr returns the listening address of this UDPMuxDefault -func (m *UDPMuxDefault) LocalAddr() net.Addr { +// LocalAddr returns the listening address of this SingleSocketUDPMux +func (m *SingleSocketUDPMux) LocalAddr() net.Addr { return m.params.UDPConn.LocalAddr() } // GetListenAddresses returns the list of addresses that this mux is listening on -func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { +func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr { m.updateLocalAddresses() m.mu.Lock() @@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { // GetConn returns a PacketConn given the connection's ufrag and network address // creates the connection if an existing one can't be found -func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { +func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) { // don't check addr for mux using unspecified address m.mu.Lock() lenLocalAddrs := len(m.localAddrsForUnspecified) @@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er return conn, nil } - c := m.createMuxedConn(ufrag) + c := m.createMuxedConn(ufrag, candidateID) go func() { <-c.CloseChannel() m.RemoveConnByUfrag(ufrag) }() + m.candidateConnMap[candidateID] = c + if isIPv6 { m.connsIPv6[ufrag] = c } else { @@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er } // RemoveConnByUfrag stops and removes the muxed packet connection -func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { +func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) { removedConns := make([]*udpMuxedConn, 0, 2) // Keep lock section small to avoid deadlock with conn lock @@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { if c, ok := m.connsIPv4[ufrag]; ok { delete(m.connsIPv4, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } if c, ok := m.connsIPv6[ufrag]; ok { delete(m.connsIPv6, ufrag) removedConns = append(removedConns, c) + delete(m.candidateConnMap, c.GetCandidateID()) } m.mu.Unlock() @@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { } // IsClosed returns true if the mux had been closed -func (m *UDPMuxDefault) IsClosed() bool { +func (m *SingleSocketUDPMux) IsClosed() bool { select { case <-m.closedChan: return true @@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool { } // Close the mux, no further connections could be created -func (m *UDPMuxDefault) Close() error { +func (m *SingleSocketUDPMux) Close() error { var err error m.closeOnce.Do(func() { m.mu.Lock() @@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error { return err } -func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { +func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { return m.params.UDPConn.WriteTo(buf, rAddr) } -func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { +func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) { if m.IsClosed() { return } @@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } -func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { +func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn { c := newUDPMuxedConn(&udpMuxedConnParams{ - Mux: m, - Key: key, - AddrPool: m.pool, - LocalAddr: m.LocalAddr(), - Logger: m.params.Logger, + Mux: m, + Key: key, + AddrPool: m.pool, + LocalAddr: m.LocalAddr(), + Logger: m.params.Logger, + CandidateID: candidateID, }) return c } // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library -func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { - +func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { remoteAddr, ok := addr.(*net.UDPAddr) if !ok { return fmt.Errorf("underlying PacketConn did not return a UDPAddr") } - // If we have already seen this address dispatch to the appropriate destination - // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one - // muxed connection - one for the SRFLX candidate and the other one for the HOST one. - // We will then forward STUN packets to each of these connections. - m.addressMapMu.RLock() + // Try to route to specific candidate connection first + if conn := m.findCandidateConnection(msg); conn != nil { + return conn.writePacket(msg.Raw, remoteAddr) + } + + // Fallback: route to all possible connections + return m.forwardToAllConnections(msg, addr, remoteAddr) +} + +// findCandidateConnection attempts to find the specific connection for a STUN message +func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn { + candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg) + if err != nil { + return nil + } else if !ok { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()] + if !exists { + return nil + } + return conn +} + +// forwardToAllConnections forwards STUN message to all relevant connections +func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error { var destinationConnList []*udpMuxedConn + + // Add connections from address map + m.addressMapMu.RLock() if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = append(destinationConnList, storedConns...) } m.addressMapMu.RUnlock() - var isIPv6 bool - if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { - isIPv6 = true + if conn, ok := m.findConnectionByUsername(msg, addr); ok { + // If we have already seen this address dispatch to the appropriate destination + // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one + // muxed connection - one for the SRFLX candidate and the other one for the HOST one. + // We will then forward STUN packets to each of these connections. + if !m.connectionExists(conn, destinationConnList) { + destinationConnList = append(destinationConnList, conn) + } } - // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. - // However, we can take a username attribute from the STUN message which contains ufrag. - // We can use ufrag to identify the destination conn to route packet to. - attr, stunAttrErr := msg.Get(stun.AttrUsername) - if stunAttrErr == nil { - ufrag := strings.Split(string(attr), ":")[0] - - m.mu.Lock() - destinationConn := m.connsIPv4[ufrag] - if isIPv6 { - destinationConn = m.connsIPv6[ufrag] - } - - if destinationConn != nil { - exists := false - for _, conn := range destinationConnList { - if conn.params.Key == destinationConn.params.Key { - exists = true - break - } - } - if !exists { - destinationConnList = append(destinationConnList, destinationConn) - } - } - m.mu.Unlock() - } - - // Forward STUN packets to each destination connections even thought the STUN packet might not belong there. - // It will be discarded by the further ICE candidate logic if so. + // Forward to all found connections for _, conn := range destinationConnList { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { log.Errorf("could not write packet: %v", err) } } - return nil } -func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { +// findConnectionByUsername finds connection using username attribute from STUN message +func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) { + attr, err := msg.Get(stun.AttrUsername) + if err != nil { + return nil, false + } + + ufrag := strings.Split(string(attr), ":")[0] + isIPv6 := isIPv6Address(addr) + + m.mu.Lock() + defer m.mu.Unlock() + + return m.getConn(ufrag, isIPv6) +} + +// connectionExists checks if a connection already exists in the list +func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool { + for _, conn := range conns { + if conn.params.Key == target.params.Key { + return true + } + } + return false +} + +func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { if isIPv6 { val, ok = m.connsIPv6[ufrag] } else { @@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o return } +func isIPv6Address(addr net.Addr) bool { + if udpAddr, ok := addr.(*net.UDPAddr); ok { + return udpAddr.IP.To4() == nil + } + return false +} + type bufferHolder struct { buf []byte } diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/udpmux/mux_generic.go similarity index 85% rename from client/iface/bind/udp_mux_generic.go rename to client/iface/udpmux/mux_generic.go index 63f786d2b..cf3043be0 100644 --- a/client/iface/bind/udp_mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -1,12 +1,12 @@ //go:build !ios -package bind +package udpmux import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { conn.RemoveAddress(addr) diff --git a/client/iface/udpmux/mux_ios.go b/client/iface/udpmux/mux_ios.go new file mode 100644 index 000000000..4cf211d8f --- /dev/null +++ b/client/iface/udpmux/mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package udpmux + +func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/udpmux/universal.go similarity index 97% rename from client/iface/bind/udp_mux_universal.go rename to client/iface/udpmux/universal.go index a1f517dcd..43bfedaaa 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/udpmux/universal.go @@ -1,4 +1,4 @@ -package bind +package udpmux /* Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. @@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error) // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // It then passes packets to the UDPMux that does the actual connection muxing. type UniversalUDPMuxDefault struct { - *UDPMuxDefault + *SingleSocketUDPMux params UniversalUDPMuxParams // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents @@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - udpMuxParams := UDPMuxParams{ + udpMuxParams := Params{ Logger: params.Logger, UDPConn: m.params.UDPConn, Net: m.params.Net, } - m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) + m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams) return m } @@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // and return a unique connection per server. -func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { - return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) +func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) { + return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID) } // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. @@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A } return nil } - return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) + return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr) } // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. diff --git a/client/internal/engine.go b/client/internal/engine.go index 61ff41600..9dc744434 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -29,9 +29,9 @@ import ( "github.com/netbirdio/netbird/client/firewall" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" @@ -166,7 +166,7 @@ type Engine struct { wgInterface WGIface - udpMux *bind.UniversalUDPMuxDefault + udpMux *udpmux.UniversalUDPMuxDefault // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -461,7 +461,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), } @@ -1326,7 +1326,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.UDPMuxDefault, + UDPMux: e.udpMux.SingleSocketUDPMux, UDPMuxSrflx: e.udpMux, NATExternalIPs: e.parseNATExternalIPMappings(), }, diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 90c8cbc60..4d2e81f43 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,10 +26,11 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" @@ -84,7 +85,7 @@ type MockWGIface struct { NameFunc func() string AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface - UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpFunc func() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error @@ -134,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface { return m.ToInterfaceFunc() } -func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { +func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) { return m.UpFunc() } @@ -413,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { if err != nil { t.Fatal(err) } - engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) + engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index bf96153ea..690fdb7cc 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -9,9 +9,9 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/monotime" @@ -24,7 +24,7 @@ type wgIfaceBase interface { Name() string Address() wgaddr.Address ToInterface() *net.Interface - Up() (*bind.UniversalUDPMuxDefault, error) + Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 896c55b6c..4e85ba0fa 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -9,11 +9,10 @@ import ( "time" "github.com/pion/ice/v4" - "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/internal/peer/conntype" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" @@ -55,10 +54,6 @@ type WorkerICE struct { sessionID ICESessionID muxAgent sync.Mutex - StunTurn []*stun.URI - - sentExtraSrflx bool - localUfrag string localPwd string @@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { w.muxAgent.Unlock() return } - w.sentExtraSrflx = false w.agent = agent w.agentDialerCancel = dialerCancel w.agentConnecting = true @@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA w.log.Errorf("error while handling remote candidate") return } + + if shouldAddExtraCandidate(candidate) { + // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + + if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil { + w.log.Errorf("error while handling remote candidate") + return + } + } } func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { @@ -327,7 +336,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) return } - mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault) if !ok { w.log.Warn("invalid udp mux conversion") return @@ -354,26 +363,6 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) } }() - - if !w.shouldSendExtraSrflxCandidate(candidate) { - return - } - - // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - w.log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - w.sentExtraSrflx = true - - go func() { - err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) - if err != nil { - w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) - } - }() } func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { @@ -424,22 +413,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia } } -func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { - if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { - return true - } - return false -} - func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { - isControlling := w.config.LocalKey > w.config.Key - if isControlling { - return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + if isController(w.config) { + return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } } +func shouldAddExtraCandidate(candidate ice.Candidate) bool { + if candidate.Type() != ice.CandidateTypeServerReflexive { + return false + } + + if candidate.Port() == candidate.RelatedAddress().Port { + return false + } + + // in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates + // in newer version we generate locally the extra candidate + if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok { + return false + } + return true +} + func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { relatedAdd := candidate.RelatedAddress() ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ @@ -455,6 +453,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive } for _, e := range candidate.Extensions() { + // overwrite the original candidate ID with the new one to avoid candidate duplication + if e.Key == ice.ExtensionKeyCandidateID { + e.Value = candidate.ID() + } if err := ec.AddExtension(e); err != nil { return nil, err } diff --git a/go.mod b/go.mod index 70e52875f..23aa45277 100644 --- a/go.mod +++ b/go.mod @@ -261,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 -replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 +replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 diff --git a/go.sum b/go.sum index 3fdef5d08..7096be3fe 100644 --- a/go.sum +++ b/go.sum @@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= -github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= -github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= +github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= From 47e64d72dbe69ac3245e7c7140e2db868c8c8782 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 11 Sep 2025 16:21:09 +0200 Subject: [PATCH 05/30] [client] Fix client status check (#4474) The client status is not enough to protect the RPC calls from concurrency issues, because it is handled internally in the client in an asynchronous way. --- client/internal/connect.go | 7 +- client/server/server.go | 408 ++++++++++++++++++----------------- client/server/server_test.go | 1 + 3 files changed, 213 insertions(+), 203 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index f20b8d361..33cd4b4a1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -280,15 +280,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan return wrapErr(err) } - log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) if runningChan != nil { - select { - case runningChan <- struct{}{}: - default: - } + close(runningChan) + runningChan = nil } <-engineCtx.Done() diff --git a/client/server/server.go b/client/server/server.go index d89c7ce91..fae342f78 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -65,6 +65,8 @@ type Server struct { mutex sync.Mutex config *profilemanager.Config proto.UnimplementedDaemonServiceServer + clientRunning bool // protected by mutex + clientRunningChan chan struct{} connectClient *internal.ConnectClient @@ -103,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -172,8 +175,12 @@ func (s *Server) Start() error { return nil } - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) - + if s.clientRunning { + return nil + } + s.clientRunning = true + s.clientRunningChan = make(chan struct{}, 1) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan) return nil } @@ -204,12 +211,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - runningChan chan struct{}, -) { - backOff := getConnectWithBackoff(ctx) - retryStarted := false +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) { + defer func() { + s.mutex.Lock() + s.clientRunning = false + s.mutex.Unlock() + }() + if s.config.DisableAutoConnect { + if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil { + log.Debugf("run client connection exited with error: %v", err) + } + log.Tracef("client connection exited") + return + } + + backOff := getConnectWithBackoff(ctx) go func() { t := time.NewTicker(24 * time.Hour) for { @@ -218,91 +235,34 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage t.Stop() return case <-t.C: - if retryStarted { - - mgmtState := statusRecorder.GetManagementState() - signalState := statusRecorder.GetSignalState() - if mgmtState.Connected && signalState.Connected { - log.Tracef("resetting status") - retryStarted = false - } else { - log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) - } + mgmtState := statusRecorder.GetManagementState() + signalState := statusRecorder.GetSignalState() + if mgmtState.Connected && signalState.Connected { + log.Tracef("resetting status") + backOff.Reset() + } else { + log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) } } } }() runOperation := func() error { - log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) - s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) - - err := s.connectClient.Run(runningChan) + err := s.connect(ctx, profileConfig, statusRecorder, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) + return err } - if config.DisableAutoConnect { - return backoff.Permanent(err) - } - - if !retryStarted { - retryStarted = true - backOff.Reset() - } - - log.Tracef("client connection exited") - return fmt.Errorf("client connection exited") + log.Tracef("client connection exited gracefully, do not need to retry") + return nil } - err := backoff.Retry(runOperation, backOff) - if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled { - log.Errorf("received an error when trying to connect: %v", err) - } else { - log.Tracef("retry canceled") + if err := backoff.Retry(runOperation, backOff); err != nil { + log.Errorf("operation failed: %v", err) } } -// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries -func getConnectWithBackoff(ctx context.Context) backoff.BackOff { - initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) - maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) - maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) - multiplier := defaultRetryMultiplier - - if envValue := os.Getenv(retryMultiplierVar); envValue != "" { - // parse the multiplier from the environment variable string value to float64 - value, err := strconv.ParseFloat(envValue, 64) - if err != nil { - log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) - } else { - multiplier = value - } - } - - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: initialInterval, - RandomizationFactor: 1, - Multiplier: multiplier, - MaxInterval: maxInterval, - MaxElapsedTime: maxElapsedTime, // 14 days - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// parseEnvDuration parses the environment variable and returns the duration -func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { - if envValue := os.Getenv(envVar); envValue != "" { - if duration, err := time.ParseDuration(envValue); err == nil { - return duration - } - log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) - } - return defaultDuration -} - // loginAttempt attempts to login using the provided information. it returns a status in case something fails func (s *Server) loginAttempt(ctx context.Context, setupKey, jwtToken string) (internal.StatusType, error) { var status internal.StatusType @@ -716,11 +676,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) + if !s.clientRunning { + s.clientRunning = true + s.clientRunningChan = make(chan struct{}, 1) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan) + } for { select { - case <-runningChan: + case <-s.clientRunningChan: s.isSessionActive.Store(true) return &proto.UpResponse{}, nil case <-callerCtx.Done(): @@ -1127,6 +1090,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p }, nil } +// AddProfile adds a new profile to the daemon. +func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.checkProfilesDisabled() { + return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) + } + + if msg.ProfileName == "" || msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") + } + + if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to create profile: %v", err) + return nil, fmt.Errorf("failed to create profile: %w", err) + } + + return &proto.AddProfileResponse{}, nil +} + +// RemoveProfile removes a profile from the daemon. +func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { + return nil, err + } + + if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { + log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) + } + + if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { + log.Errorf("failed to remove profile: %v", err) + return nil, fmt.Errorf("failed to remove profile: %w", err) + } + + return &proto.RemoveProfileResponse{}, nil +} + +// ListProfiles lists all profiles in the daemon. +func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if msg.Username == "" { + return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") + } + + profiles, err := s.profileManager.ListProfiles(msg.Username) + if err != nil { + log.Errorf("failed to list profiles: %v", err) + return nil, fmt.Errorf("failed to list profiles: %w", err) + } + + response := &proto.ListProfilesResponse{ + Profiles: make([]*proto.Profile, len(profiles)), + } + for i, profile := range profiles { + response.Profiles[i] = &proto.Profile{ + Name: profile.Name, + IsActive: profile.IsActive, + } + } + + return response, nil +} + +// GetActiveProfile returns the active profile in the daemon. +func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + activeProfile, err := s.profileManager.GetActiveProfileState() + if err != nil { + log.Errorf("failed to get active profile state: %v", err) + return nil, fmt.Errorf("failed to get active profile state: %w", err) + } + + return &proto.GetActiveProfileResponse{ + ProfileName: activeProfile.Name, + Username: activeProfile.Username, + }, nil +} + +// GetFeatures returns the features supported by the daemon. +func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + features := &proto.GetFeaturesResponse{ + DisableProfiles: s.checkProfilesDisabled(), + DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + } + + return features, nil +} + +func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error { + log.Tracef("running client connection") + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) + if err := s.connectClient.Run(runningChan); err != nil { + return err + } + return nil +} + +func (s *Server) checkProfilesDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.profilesDisabled { + return true + } + + return false +} + +func (s *Server) checkUpdateSettingsDisabled() bool { + // Check if the environment variable is set to disable profiles + if s.updateSettingsDisabled { + return true + } + + return false +} + func (s *Server) onSessionExpire() { if runtime.GOOS != "windows" { isUIActive := internal.CheckUIApp() @@ -1138,6 +1229,45 @@ func (s *Server) onSessionExpire() { } } +// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries +func getConnectWithBackoff(ctx context.Context) backoff.BackOff { + initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime) + maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval) + maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime) + multiplier := defaultRetryMultiplier + + if envValue := os.Getenv(retryMultiplierVar); envValue != "" { + // parse the multiplier from the environment variable string value to float64 + value, err := strconv.ParseFloat(envValue, 64) + if err != nil { + log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier) + } else { + multiplier = value + } + } + + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: initialInterval, + RandomizationFactor: 1, + Multiplier: multiplier, + MaxInterval: maxInterval, + MaxElapsedTime: maxElapsedTime, // 14 days + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// parseEnvDuration parses the environment variable and returns the duration +func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration { + if envValue := os.Getenv(envVar); envValue != "" { + if duration, err := time.ParseDuration(envValue); err == nil { + return duration + } + log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration) + } + return defaultDuration +} + func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus := proto.FullStatus{ ManagementState: &proto.ManagementState{}, @@ -1252,121 +1382,3 @@ func sendTerminalNotification() error { return wallCmd.Wait() } - -// AddProfile adds a new profile to the daemon. -func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if s.checkProfilesDisabled() { - return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) - } - - if msg.ProfileName == "" || msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided") - } - - if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to create profile: %v", err) - return nil, fmt.Errorf("failed to create profile: %w", err) - } - - return &proto.AddProfileResponse{}, nil -} - -// RemoveProfile removes a profile from the daemon. -func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if err := s.validateProfileOperation(msg.ProfileName, false); err != nil { - return nil, err - } - - if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil { - log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err) - } - - if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil { - log.Errorf("failed to remove profile: %v", err) - return nil, fmt.Errorf("failed to remove profile: %w", err) - } - - return &proto.RemoveProfileResponse{}, nil -} - -// ListProfiles lists all profiles in the daemon. -func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - if msg.Username == "" { - return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided") - } - - profiles, err := s.profileManager.ListProfiles(msg.Username) - if err != nil { - log.Errorf("failed to list profiles: %v", err) - return nil, fmt.Errorf("failed to list profiles: %w", err) - } - - response := &proto.ListProfilesResponse{ - Profiles: make([]*proto.Profile, len(profiles)), - } - for i, profile := range profiles { - response.Profiles[i] = &proto.Profile{ - Name: profile.Name, - IsActive: profile.IsActive, - } - } - - return response, nil -} - -// GetActiveProfile returns the active profile in the daemon. -func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - activeProfile, err := s.profileManager.GetActiveProfileState() - if err != nil { - log.Errorf("failed to get active profile state: %v", err) - return nil, fmt.Errorf("failed to get active profile state: %w", err) - } - - return &proto.GetActiveProfileResponse{ - ProfileName: activeProfile.Name, - Username: activeProfile.Username, - }, nil -} - -// GetFeatures returns the features supported by the daemon. -func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) { - s.mutex.Lock() - defer s.mutex.Unlock() - - features := &proto.GetFeaturesResponse{ - DisableProfiles: s.checkProfilesDisabled(), - DisableUpdateSettings: s.checkUpdateSettingsDisabled(), - } - - return features, nil -} - -func (s *Server) checkProfilesDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.profilesDisabled { - return true - } - - return false -} - -func (s *Server) checkUpdateSettingsDisabled() bool { - // Check if the environment variable is set to disable profiles - if s.updateSettingsDisabled { - return true - } - - return false -} diff --git a/client/server/server_test.go b/client/server/server_test.go index 87889cbce..45a1aa5c7 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" From cf7f6c355f713e83cf171b79e08dac60b316e4fd Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 11 Sep 2025 22:20:10 +0300 Subject: [PATCH 06/30] [misc] Remove default zitadel admin user in deployment script (#4482) * Delete default zitadel-admin user during initialization Signed-off-by: bcmmbaga * Refactor Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- .../getting-started-with-zitadel.sh | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 2d7c65cbe..cfec1000e 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -328,6 +328,45 @@ delete_auto_service_user() { echo "$PARSED_RESPONSE" } +delete_default_zitadel_admin() { + INSTANCE_URL=$1 + PAT=$2 + + # Search for the default zitadel-admin user + RESPONSE=$( + curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + -d '{ + "queries": [ + { + "userNameQuery": { + "userName": "zitadel-admin@", + "method": "TEXT_QUERY_METHOD_STARTS_WITH" + } + } + ] + }' + ) + + DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty') + + if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then + echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID" + + RESPONSE=$( + curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \ + -H "Authorization: Bearer $PAT" \ + -H "Content-Type: application/json" \ + ) + PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"') + handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE" + + else + echo "Default zitadel-admin user not found: $RESPONSE" + fi +} + init_zitadel() { echo -e "\nInitializing Zitadel with NetBird's applications\n" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" @@ -346,6 +385,9 @@ init_zitadel() { echo -n "Waiting for Zitadel to become ready " wait_api "$INSTANCE_URL" "$PAT" + echo "Deleting default zitadel-admin user..." + delete_default_zitadel_admin "$INSTANCE_URL" "$PAT" + # create the zitadel project echo "Creating new zitadel project" PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") From 0c6f671a7c8ab9befb9580c9297a1b0de05286cb Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 12 Sep 2025 09:31:03 +0200 Subject: [PATCH 07/30] Refactor healthcheck sender and receiver to use configurable options (#4433) --- .github/workflows/golang-test-linux.yml | 2 +- shared/relay/client/manager.go | 7 +- shared/relay/client/picker.go | 18 +++--- shared/relay/client/picker_test.go | 10 +-- shared/relay/healthcheck/env.go | 24 +++++++ shared/relay/healthcheck/env_test.go | 36 +++++++++++ shared/relay/healthcheck/receiver.go | 32 +++++++-- shared/relay/healthcheck/receiver_test.go | 79 ++++++----------------- shared/relay/healthcheck/sender.go | 76 ++++++++++++---------- shared/relay/healthcheck/sender_test.go | 78 ++++++---------------- 10 files changed, 186 insertions(+), 176 deletions(-) create mode 100644 shared/relay/healthcheck/env.go create mode 100644 shared/relay/healthcheck/env_test.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index f7b4e238f..ba36c013b 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -217,7 +217,7 @@ jobs: - arch: "386" raceFlag: "" - arch: "amd64" - raceFlag: "" + raceFlag: "-race" runs-on: ubuntu-22.04 steps: - name: Install Go diff --git a/shared/relay/client/manager.go b/shared/relay/client/manager.go index a40343fb1..6220e7f6b 100644 --- a/shared/relay/client/manager.go +++ b/shared/relay/client/manager.go @@ -78,9 +78,10 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin tokenStore: tokenStore, mtu: mtu, serverPicker: &ServerPicker{ - TokenStore: tokenStore, - PeerID: peerID, - MTU: mtu, + TokenStore: tokenStore, + PeerID: peerID, + MTU: mtu, + ConnectionTimeout: defaultConnectionTimeout, }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), diff --git a/shared/relay/client/picker.go b/shared/relay/client/picker.go index b6c7b5e8a..39d0ba072 100644 --- a/shared/relay/client/picker.go +++ b/shared/relay/client/picker.go @@ -13,11 +13,8 @@ import ( ) const ( - maxConcurrentServers = 7 -) - -var ( - connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 + defaultConnectionTimeout = 30 * time.Second ) type connResult struct { @@ -27,14 +24,15 @@ type connResult struct { } type ServerPicker struct { - TokenStore *auth.TokenStore - ServerURLs atomic.Value - PeerID string - MTU uint16 + TokenStore *auth.TokenStore + ServerURLs atomic.Value + PeerID string + MTU uint16 + ConnectionTimeout time.Duration } func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { - ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout) defer cancel() totalServers := len(sp.ServerURLs.Load().([]string)) diff --git a/shared/relay/client/picker_test.go b/shared/relay/client/picker_test.go index 28167c5ce..fb3fa7375 100644 --- a/shared/relay/client/picker_test.go +++ b/shared/relay/client/picker_test.go @@ -8,15 +8,15 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { - connectionTimeout = 5 * time.Second - + timeout := 5 * time.Second sp := ServerPicker{ - TokenStore: nil, - PeerID: "test", + TokenStore: nil, + PeerID: "test", + ConnectionTimeout: timeout, } sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) - ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) + ctx, cancel := context.WithTimeout(context.Background(), timeout+1) defer cancel() go func() { diff --git a/shared/relay/healthcheck/env.go b/shared/relay/healthcheck/env.go new file mode 100644 index 000000000..2b584c195 --- /dev/null +++ b/shared/relay/healthcheck/env.go @@ -0,0 +1,24 @@ +package healthcheck + +import ( + "os" + "strconv" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" +) + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/shared/relay/healthcheck/env_test.go b/shared/relay/healthcheck/env_test.go new file mode 100644 index 000000000..2e14bb8bf --- /dev/null +++ b/shared/relay/healthcheck/env_test.go @@ -0,0 +1,36 @@ +package healthcheck + +import ( + "os" + "testing" +) + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/shared/relay/healthcheck/receiver.go b/shared/relay/healthcheck/receiver.go index b3503d5db..90f795bbe 100644 --- a/shared/relay/healthcheck/receiver.go +++ b/shared/relay/healthcheck/receiver.go @@ -7,10 +7,15 @@ import ( log "github.com/sirupsen/logrus" ) -var ( - heartbeatTimeout = healthCheckInterval + 10*time.Second +const ( + defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second ) +type ReceiverOptions struct { + HeartbeatTimeout time.Duration + AttemptThreshold int +} + // Receiver is a healthcheck receiver // It will listen for heartbeat and check if the heartbeat is not received in a certain time // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work @@ -27,6 +32,23 @@ type Receiver struct { // NewReceiver creates a new healthcheck receiver and start the timer in the background func NewReceiver(log *log.Entry) *Receiver { + opts := ReceiverOptions{ + HeartbeatTimeout: defaultHeartbeatTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewReceiverWithOpts(log, opts) +} + +func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver { + heartbeatTimeout := opts.HeartbeatTimeout + if heartbeatTimeout <= 0 { + heartbeatTimeout = defaultHeartbeatTimeout + } + attemptThreshold := opts.AttemptThreshold + if attemptThreshold <= 0 { + attemptThreshold = defaultAttemptThreshold + } + ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ @@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver { ctx: ctx, ctxCancel: ctxCancel, heartbeat: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + attemptThreshold: attemptThreshold, } - go r.waitForHealthcheck() + go r.waitForHealthcheck(heartbeatTimeout) return r } @@ -55,7 +77,7 @@ func (r *Receiver) Stop() { r.ctxCancel() } -func (r *Receiver) waitForHealthcheck() { +func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) { ticker := time.NewTicker(heartbeatTimeout) defer ticker.Stop() defer r.ctxCancel() diff --git a/shared/relay/healthcheck/receiver_test.go b/shared/relay/healthcheck/receiver_test.go index 2794159f6..b20cc5124 100644 --- a/shared/relay/healthcheck/receiver_test.go +++ b/shared/relay/healthcheck/receiver_test.go @@ -2,31 +2,18 @@ package healthcheck import ( "context" - "fmt" - "os" - "sync" "testing" "time" log "github.com/sirupsen/logrus" ) -// Mutex to protect global variable access in tests -var testMutex sync.Mutex - func TestNewReceiver(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 5 * time.Second - testMutex.Unlock() - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 5 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) { } func TestNewReceiverNotReceive(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 1 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 1 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() select { @@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) { } func TestNewReceiverAck(t *testing.T) { - testMutex.Lock() - originalTimeout := heartbeatTimeout - heartbeatTimeout = 2 * time.Second - testMutex.Unlock() - - defer func() { - testMutex.Lock() - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - - r := NewReceiver(log.WithContext(context.Background())) + opts := ReceiverOptions{ + HeartbeatTimeout: 2 * time.Second, + } + r := NewReceiverWithOpts(log.WithContext(context.Background()), opts) defer r.Stop() r.Heartbeat() @@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - testMutex.Lock() - originalInterval := healthCheckInterval - originalTimeout := heartbeatTimeout - healthCheckInterval = 1 * time.Second - heartbeatTimeout = healthCheckInterval + 500*time.Millisecond - testMutex.Unlock() + healthCheckInterval := 1 * time.Second - defer func() { - testMutex.Lock() - healthCheckInterval = originalInterval - heartbeatTimeout = originalTimeout - testMutex.Unlock() - }() - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := ReceiverOptions{ + HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond, + AttemptThreshold: tc.threshold, + } - receiver := NewReceiver(log.WithField("test_name", tc.name)) + receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts) - testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval if tc.resetCounterOnce { receiver.Heartbeat() - t.Logf("reset counter once") } select { @@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { } t.Fatalf("should have timed out before %s", testTimeout) } - }) } } diff --git a/shared/relay/healthcheck/sender.go b/shared/relay/healthcheck/sender.go index 57b3015ec..771e94206 100644 --- a/shared/relay/healthcheck/sender.go +++ b/shared/relay/healthcheck/sender.go @@ -2,52 +2,76 @@ package healthcheck import ( "context" - "os" - "strconv" "time" log "github.com/sirupsen/logrus" ) const ( - defaultAttemptThreshold = 1 - defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" + defaultAttemptThreshold = 1 + + defaultHealthCheckInterval = 25 * time.Second + defaultHealthCheckTimeout = 20 * time.Second ) -var ( - healthCheckInterval = 25 * time.Second - healthCheckTimeout = 20 * time.Second -) +type SenderOptions struct { + HealthCheckInterval time.Duration + HealthCheckTimeout time.Duration + AttemptThreshold int +} // Sender is a healthcheck sender // It will send healthcheck signal to the receiver // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { - log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} + log *log.Entry + healthCheckInterval time.Duration + timeout time.Duration + ack chan struct{} alive bool attemptThreshold int } -// NewSender creates a new healthcheck sender -func NewSender(log *log.Entry) *Sender { +func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender { + if opts.HealthCheckInterval <= 0 { + opts.HealthCheckInterval = defaultHealthCheckInterval + } + if opts.HealthCheckTimeout <= 0 { + opts.HealthCheckTimeout = defaultHealthCheckTimeout + } + if opts.AttemptThreshold <= 0 { + opts.AttemptThreshold = defaultAttemptThreshold + } hc := &Sender{ - log: log, - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - ack: make(chan struct{}, 1), - attemptThreshold: getAttemptThresholdFromEnv(), + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + log: log, + healthCheckInterval: opts.HealthCheckInterval, + timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout, + ack: make(chan struct{}, 1), + attemptThreshold: opts.AttemptThreshold, } return hc } +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { + opts := SenderOptions{ + HealthCheckInterval: defaultHealthCheckInterval, + HealthCheckTimeout: defaultHealthCheckTimeout, + AttemptThreshold: getAttemptThresholdFromEnv(), + } + return NewSenderWithOpts(log, opts) +} + // OnHCResponse sends an acknowledgment signal to the sender func (hc *Sender) OnHCResponse() { select { @@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() { } func (hc *Sender) StartHealthCheck(ctx context.Context) { - ticker := time.NewTicker(healthCheckInterval) + ticker := time.NewTicker(hc.healthCheckInterval) defer ticker.Stop() - timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + timeoutTicker := time.NewTicker(hc.timeout) defer timeoutTicker.Stop() defer close(hc.HealthCheck) @@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { } } } - -func (hc *Sender) getTimeoutTime() time.Duration { - return healthCheckInterval + healthCheckTimeout -} - -func getAttemptThresholdFromEnv() int { - if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { - threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) - if err != nil { - log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) - return defaultAttemptThreshold - } - return int(threshold) - } - return defaultAttemptThreshold -} diff --git a/shared/relay/healthcheck/sender_test.go b/shared/relay/healthcheck/sender_test.go index 23446366a..122fe0f16 100644 --- a/shared/relay/healthcheck/sender_test.go +++ b/shared/relay/healthcheck/sender_test.go @@ -2,26 +2,23 @@ package healthcheck import ( "context" - "fmt" - "os" "testing" "time" log "github.com/sirupsen/logrus" ) -func TestMain(m *testing.M) { - // override the health check interval to speed up the test - healthCheckInterval = 2 * time.Second - healthCheckTimeout = 100 * time.Millisecond - code := m.Run() - os.Exit(code) -} +var ( + testOpts = SenderOptions{ + HealthCheckInterval: 2 * time.Second, + HealthCheckTimeout: 100 * time.Millisecond, + } +) func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) select { case <-hc.Timeout: - case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond): t.Fatalf("health check is not timed out") } } func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender(log.WithContext(ctx)) + hc := NewSenderWithOpts(log.WithContext(ctx), testOpts) go hc.StartHealthCheck(ctx) iterations := 0 @@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) { hc.OnHCResponse() case <-hc.Timeout: t.Fatalf("health check is timed out") - case <-time.After(healthCheckInterval + 100*time.Millisecond): + case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond): t.Fatalf("health check not received") } } @@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { for _, tc := range testsCases { t.Run(tc.name, func(t *testing.T) { - originalInterval := healthCheckInterval - originalTimeout := healthCheckTimeout - healthCheckInterval = 1 * time.Second - healthCheckTimeout = 500 * time.Millisecond - - //nolint:tenv - os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) - defer os.Unsetenv(defaultAttemptThresholdEnv) + opts := SenderOptions{ + HealthCheckInterval: 1 * time.Second, + HealthCheckTimeout: 500 * time.Millisecond, + AttemptThreshold: tc.threshold, + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - sender := NewSender(log.WithField("test_name", tc.name)) + sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts) senderExit := make(chan struct{}) go func() { sender.StartHealthCheck(ctx) @@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { } }() - testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval select { case <-sender.Timeout: @@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout }) } } - -//nolint:tenv -func TestGetAttemptThresholdFromEnv(t *testing.T) { - tests := []struct { - name string - envValue string - expected int - }{ - {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, - {"Custom attempt threshold when env is set to a valid integer", "3", 3}, - {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.envValue == "" { - os.Unsetenv(defaultAttemptThresholdEnv) - } else { - os.Setenv(defaultAttemptThresholdEnv, tt.envValue) - } - - result := getAttemptThresholdFromEnv() - if result != tt.expected { - t.Fatalf("Expected %d, got %d", tt.expected, result) - } - - os.Unsetenv(defaultAttemptThresholdEnv) - }) - } -} From bd23ab925e61c79808ef6b805927c922efef02dd Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 15 Sep 2025 15:08:53 +0200 Subject: [PATCH 08/30] [client] Fix ICE latency handling (#4501) The GetSelectedCandidatePair() does not carry the latency information. --- client/internal/peer/worker_ice.go | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4e85ba0fa..eb886a4d3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -218,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates [] return nil, err } - if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { + if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) { + w.onICESelectedCandidatePair(agent, c1, c2) + }); err != nil { return nil, err } @@ -365,26 +367,17 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { }() } -func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { +func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) { w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.config.Key) - w.muxAgent.Lock() - - pair, err := w.agent.GetSelectedCandidatePair() - if err != nil { - w.log.Warnf("failed to get selected candidate pair: %s", err) - w.muxAgent.Unlock() + pairStat, ok := agent.GetSelectedCandidatePairStats() + if !ok { + w.log.Warnf("failed to get selected candidate pair stats") return } - if pair == nil { - w.log.Warnf("selected candidate pair is nil, cannot proceed") - w.muxAgent.Unlock() - return - } - w.muxAgent.Unlock() - duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second)) + duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second)) if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { w.log.Debugf("failed to update latency for peer: %s", err) return From 3130cce72d59165060a6af643bc529be11da2802 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 15 Sep 2025 21:08:16 +0300 Subject: [PATCH 09/30] [management] Add rule ID validation for policy updates (#4499) --- management/server/policy.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/management/server/policy.go b/management/server/policy.go index 312fd53b2..3adee6397 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -167,10 +167,22 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a // validatePolicy validates the policy and its rules. func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } + + // TODO: Refactor to support multiple rules per policy + existingRuleIDs := make(map[string]bool) + for _, rule := range existingPolicy.Rules { + existingRuleIDs[rule.ID] = true + } + + for _, rule := range policy.Rules { + if rule.ID != "" && !existingRuleIDs[rule.ID] { + return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + } + } } else { policy.ID = xid.New().String() policy.AccountID = accountID From ec8d83ade442e0d212db4202fdf60bf826c6da62 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 18 Sep 2025 18:13:29 +0700 Subject: [PATCH 10/30] [client] [UI] Down & Up NetBird Async When Settings Updated [client] [UI] Down & Up NetBird Async When Settings Updated --- client/ui/client_ui.go | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 2403b5d05..25d7380a9 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -529,7 +529,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form { var req proto.SetConfigRequest req.ProfileName = activeProf.Name req.Username = currUser.Username - + if iMngURL != "" { req.ManagementUrl = iMngURL } @@ -563,27 +563,28 @@ func (s *serviceClient) getSettingsForm() *widget.Form { return } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) - if err != nil { - log.Errorf("get service status: %v", err) - dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) - return - } - if status.Status == string(internal.StatusConnected) { - // run down & up - _, err = conn.Down(s.ctx, &proto.DownRequest{}) + go func() { + status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("down service: %v", err) - } - - _, err = conn.Up(s.ctx, &proto.UpRequest{}) - if err != nil { - log.Errorf("up service: %v", err) - dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + log.Errorf("get service status: %v", err) + dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings) return } - } + if status.Status == string(internal.StatusConnected) { + // run down & up + _, err = conn.Down(s.ctx, &proto.DownRequest{}) + if err != nil { + log.Errorf("down service: %v", err) + } + _, err = conn.Up(s.ctx, &proto.UpRequest{}) + if err != nil { + log.Errorf("up service: %v", err) + dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings) + return + } + } + }() } }, OnCancel: func() { From 2c87fa623654c5eef76bc0226062290201eef13a Mon Sep 17 00:00:00 2001 From: Diego Romar Date: Thu, 18 Sep 2025 10:07:42 -0300 Subject: [PATCH 11/30] [android] Add OnLoginSuccess callback to URLOpener interface (#4492) The callback will be fired once login -> internal.Login completes without errors --- client/android/login.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client/android/login.go b/client/android/login.go index d8ac645e2..0df78dbc3 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -33,6 +33,7 @@ type ErrListener interface { // the backend want to show an url for the user type URLOpener interface { Open(string) + OnLoginSuccess() } // Auth can register or login new client @@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error { err = a.withBackOff(a.ctx, func() error { err := internal.Login(a.ctx, a.config, "", jwtToken) + + if err == nil { + go urlOpener.OnLoginSuccess() + } + if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) { return nil } From dc30dcacce4c322502975f1f491e6774efd7e1e9 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Thu, 18 Sep 2025 19:57:07 +0300 Subject: [PATCH 12/30] [management] Filter DNS records to include only peers to connect (#4517) DNS record filtering to only include peers that a peer can connect to, reducing unnecessary DNS data in the peer's network map. - Adds a new `filterZoneRecordsForPeers` function to filter DNS records based on peer connectivity - Modifies `GetPeerNetworkMap` to use filtered DNS records instead of all records in the custom zone - Includes comprehensive test coverage for the new filtering functionality --- management/server/types/account.go | 27 +++++- management/server/types/account_test.go | 109 ++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 1 deletion(-) diff --git a/management/server/types/account.go b/management/server/types/account.go index 9ac2568a0..ca075b9f6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -302,7 +302,11 @@ func (a *Account) GetPeerNetworkMap( var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -1651,3 +1655,24 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } + +// filterZoneRecordsForPeers filters DNS records to only include peers to connect. +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) + peerIPs := make(map[string]struct{}) + + // Add peer's own IP to include its own DNS records + peerIPs[peer.IP.String()] = struct{}{} + + for _, peerToConnect := range peersToConnect { + peerIPs[peerToConnect.IP.String()] = struct{}{} + } + + for _, record := range customZone.Records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f8ab1d627..cd221b590 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,14 +2,17 @@ package types import ( "context" + "fmt" "net" "net/netip" "slices" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } + +func Test_FilterZoneRecordsForPeers(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + customZone nbdns.CustomZone + peersToConnect []*nbpeer.Peer + expectedRecords []nbdns.SimpleRecord + }{ + { + name: "empty peers to connect", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + { + name: "multiple peers multiple records match", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for i := 1; i <= 100; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + peersToConnect: func() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + peers = append(peers, &nbpeer.Peer{ + ID: fmt.Sprintf("peer%d", i), + IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + }) + } + return peers + }(), + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + { + name: "peers with multiple DNS labels", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, + {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + assert.Equal(t, len(tt.expectedRecords), len(result)) + assert.ElementsMatch(t, tt.expectedRecords, result) + }) + } +} From 90577682e45b5fbcadc7d7ae814a0dbb46a1621f Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 19 Sep 2025 14:06:44 +0300 Subject: [PATCH 13/30] Add a new product demo video (#4520) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ea7655869..2c5ee2ab6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +


@@ -52,7 +53,7 @@ ### Open Source Network Security in a Single Platform -centralized-network-management 1 +https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) From 55126f990cdca39d37e9520488816a0785bf7a0b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 20 Sep 2025 09:31:04 +0200 Subject: [PATCH 14/30] [client] Use native windows sock opts to avoid routing loops (#4314) - Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed - Add methods to return the next best interface after the NetBird interface. - Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method. - Some refactoring to avoid import cycles - Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var --- client/android/client.go | 2 +- client/cmd/down.go | 2 +- client/firewall/iptables/acl_linux.go | 2 +- client/firewall/iptables/router_linux.go | 2 +- client/firewall/iptables/router_linux_test.go | 2 +- client/firewall/nftables/acl_linux.go | 2 +- client/firewall/nftables/router_linux.go | 2 +- {util => client}/grpc/dialer.go | 3 +- client/iface/bind/control.go | 2 +- client/iface/bind/ice_bind.go | 2 +- client/iface/configurer/usp.go | 4 +- client/iface/device/device_kernel_unix.go | 9 +- client/iface/device/device_netstack.go | 2 +- client/iface/udpmux/mux_generic.go | 2 +- client/iface/wgproxy/ebpf/proxy.go | 2 +- client/internal/connect.go | 2 +- client/internal/dns/service_memory.go | 2 +- client/internal/dns/upstream_android.go | 2 +- client/internal/engine.go | 2 + .../internal/netflow/conntrack/conntrack.go | 2 +- client/internal/relay/relay.go | 2 +- client/internal/routemanager/manager.go | 16 +- .../systemops/systemops_android.go | 4 +- .../systemops/systemops_generic.go | 54 +--- .../systemops/systemops_generic_test.go | 16 +- .../routemanager/systemops/systemops_ios.go | 4 +- .../routemanager/systemops/systemops_linux.go | 12 +- .../routemanager/systemops/systemops_unix.go | 4 +- .../systemops/systemops_unix_test.go | 2 +- .../systemops/systemops_windows.go | 148 ++++++++- .../systemops/systemops_windows_test.go | 2 +- client/internal/routemanager/util/ip.go | 14 +- client/internal/stdnet/dialer.go | 2 +- client/internal/stdnet/listener.go | 2 +- client/net/conn.go | 49 +++ client/net/dial.go | 82 +++++ {util => client}/net/dial_ios.go | 0 {util => client}/net/dialer.go | 1 - client/net/dialer_dial.go | 87 ++++++ {util => client}/net/dialer_init_android.go | 0 client/net/dialer_init_generic.go | 7 + {util => client}/net/dialer_init_linux.go | 0 client/net/dialer_init_windows.go | 5 + {util => client}/net/env.go | 1 + client/net/env_android.go | 24 ++ client/net/env_generic.go | 23 ++ {util => client}/net/env_linux.go | 16 +- client/net/env_windows.go | 67 +++++ client/net/hooks/hooks.go | 93 ++++++ client/net/listen.go | 47 +++ {util => client}/net/listen_ios.go | 0 {util => client}/net/listener.go | 6 +- {util => client}/net/listener_init_android.go | 0 client/net/listener_init_generic.go | 7 + {util => client}/net/listener_init_linux.go | 0 client/net/listener_init_windows.go | 8 + client/net/listener_listen.go | 153 ++++++++++ {util => client}/net/listener_listen_ios.go | 0 {util => client}/net/net.go | 14 - {util => client}/net/net_linux.go | 0 {util => client}/net/net_test.go | 0 client/net/net_windows.go | 284 ++++++++++++++++++ {util => client}/net/protectsocket_android.go | 0 flow/client/client.go | 2 +- shared/management/client/grpc.go | 2 +- shared/relay/client/dialer/quic/quic.go | 2 +- shared/relay/client/dialer/ws/ws.go | 2 +- shared/signal/client/grpc.go | 2 +- sharedsock/sock_linux.go | 6 +- util/net/conn.go | 31 -- util/net/dial.go | 58 ---- util/net/dialer_dial.go | 107 ------- util/net/dialer_init_nonlinux.go | 7 - util/net/env_generic.go | 12 - util/net/listen.go | 37 --- util/net/listener_init_nonlinux.go | 7 - util/net/listener_listen.go | 205 ------------- 77 files changed, 1180 insertions(+), 606 deletions(-) rename {util => client}/grpc/dialer.go (98%) create mode 100644 client/net/conn.go create mode 100644 client/net/dial.go rename {util => client}/net/dial_ios.go (100%) rename {util => client}/net/dialer.go (99%) create mode 100644 client/net/dialer_dial.go rename {util => client}/net/dialer_init_android.go (100%) create mode 100644 client/net/dialer_init_generic.go rename {util => client}/net/dialer_init_linux.go (100%) create mode 100644 client/net/dialer_init_windows.go rename {util => client}/net/env.go (94%) create mode 100644 client/net/env_android.go create mode 100644 client/net/env_generic.go rename {util => client}/net/env_linux.go (86%) create mode 100644 client/net/env_windows.go create mode 100644 client/net/hooks/hooks.go create mode 100644 client/net/listen.go rename {util => client}/net/listen_ios.go (100%) rename {util => client}/net/listener.go (81%) rename {util => client}/net/listener_init_android.go (100%) create mode 100644 client/net/listener_init_generic.go rename {util => client}/net/listener_init_linux.go (100%) create mode 100644 client/net/listener_init_windows.go create mode 100644 client/net/listener_listen.go rename {util => client}/net/listener_listen_ios.go (100%) rename {util => client}/net/net.go (81%) rename {util => client}/net/net_linux.go (100%) rename {util => client}/net/net_test.go (100%) create mode 100644 client/net/net_windows.go rename {util => client}/net/protectsocket_android.go (100%) delete mode 100644 util/net/conn.go delete mode 100644 util/net/dial.go delete mode 100644 util/net/dialer_dial.go delete mode 100644 util/net/dialer_init_nonlinux.go delete mode 100644 util/net/env_generic.go delete mode 100644 util/net/listen.go delete mode 100644 util/net/listener_init_nonlinux.go delete mode 100644 util/net/listener_listen.go diff --git a/client/android/client.go b/client/android/client.go index 4b4fcc9be..d2d0c37f6 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net" ) // ConnectionListener export internal Listener for mobile diff --git a/client/cmd/down.go b/client/cmd/down.go index 3ce51c678..17c152d22 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -27,7 +27,7 @@ var downCmd = &cobra.Command{ return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() conn, err := DialClientGRPCServer(ctx, daemonAddr) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 7b90000a8..ed8a7403b 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -12,7 +12,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 1e44c7a4d..081991235 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -19,7 +19,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // constants needed to manage and create iptable rules diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index e9eeff863..3490c5dad 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,7 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func isIptablesSupported() bool { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 52979d257..9ff5b8c92 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index f8fed4d80..e918d0524 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -22,7 +22,7 @@ import ( nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/util/grpc/dialer.go b/client/grpc/dialer.go similarity index 98% rename from util/grpc/dialer.go rename to client/grpc/dialer.go index f6d6d2f04..7ac950d85 100644 --- a/util/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -20,8 +20,9 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + nbnet "github.com/netbirdio/netbird/client/net" + "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" ) func WithCustomDialer() grpc.DialOption { diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go index 89bddf12c..32b07c330 100644 --- a/client/iface/bind/control.go +++ b/client/iface/bind/control.go @@ -3,7 +3,7 @@ package bind import ( wireguard "golang.zx2c4.com/wireguard/conn" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index b74f90d6c..577c7c0c4 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -17,7 +17,7 @@ import ( "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type RecvMessage struct { diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 945f1a162..f744e0127 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -17,8 +17,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/monotime" - nbnet "github.com/netbirdio/netbird/util/net" ) const ( @@ -409,7 +409,7 @@ func toBytes(s string) (int64, error) { } func getFwmark() int { - if nbnet.AdvancedRouting() { + if nbnet.AdvancedRouting() && runtime.GOOS == "linux" { return nbnet.ControlPlaneMark } return 0 diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 2ef6f6b22..cdac43a53 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -15,8 +15,8 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/sharedsock" - nbnet "github.com/netbirdio/netbird/util/net" ) type TunKernelDevice struct { @@ -101,13 +101,8 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return nil, err } - var udpConn net.PacketConn = rawSock - if !nbnet.AdvancedRouting() { - udpConn = nbnet.WrapPacketConn(rawSock) - } - bindParams := udpmux.UniversalUDPMuxParams{ - UDPConn: udpConn, + UDPConn: nbnet.WrapPacketConn(rawSock), Net: t.transportNet, FilterFn: t.filterFn, WGAddress: t.address, diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 2fcc74809..a6ef47027 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -12,7 +12,7 @@ import ( nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type TunNetstackDevice struct { diff --git a/client/iface/udpmux/mux_generic.go b/client/iface/udpmux/mux_generic.go index cf3043be0..29fc2d834 100644 --- a/client/iface/udpmux/mux_generic.go +++ b/client/iface/udpmux/mux_generic.go @@ -3,7 +3,7 @@ package udpmux import ( - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) { diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index fcdc0189d..b899f1694 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const ( diff --git a/client/internal/connect.go b/client/internal/connect.go index 33cd4b4a1..c9331baf5 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -34,7 +34,7 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) diff --git a/client/internal/dns/service_memory.go b/client/internal/dns/service_memory.go index 89d637686..6ef0ab526 100644 --- a/client/internal/dns/service_memory.go +++ b/client/internal/dns/service_memory.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type ServiceViaMemory struct { diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index 6b7dcc05e..def281f28 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" "github.com/netbirdio/netbird/client/internal/peer" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type upstreamResolver struct { diff --git a/client/internal/engine.go b/client/internal/engine.go index 9dc744434..d4c465efb 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -446,6 +446,8 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } + + // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index dbb4747a5..a4ffa3a25 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -14,7 +14,7 @@ import ( "github.com/ti-mo/netfilter" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) const defaultChannelSize = 100 diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 8c3d5a571..fa208716f 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/stdnet" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ProbeResult holds the info about the result of a relay probe request diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index a6775c45a..04513bbe4 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -36,9 +36,9 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/route" relayClient "github.com/netbirdio/netbird/shared/relay/client" - nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) @@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager { notifier := notifier.NewNotifier() sysOps := systemops.NewSysOps(config.WGInterface, notifier) + if runtime.GOOS == "windows" && config.WGInterface != nil { + nbnet.SetVPNInterfaceName(config.WGInterface.Name()) + } + dm := &DefaultManager{ ctx: mCTX, stop: cancel, @@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error { return nil } - if err := m.sysOps.CleanupRouting(nil); err != nil { + if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error { ips := resolveURLsToIPs(initialAddresses) - if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil { return fmt.Errorf("setup routing: %w", err) } @@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { - if err := m.sysOps.CleanupRouting(stateManager); err != nil { + if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") } + + if runtime.GOOS == "windows" { + nbnet.SetVPNInterfaceName("") + } } m.mux.Lock() diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index a375ce832..7cb8dae93 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -12,11 +12,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 128afa2a5..26a548634 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -3,7 +3,6 @@ package systemops import ( - "context" "errors" "fmt" "net" @@ -22,7 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + "github.com/netbirdio/netbird/client/net/hooks" ) const localSubnetsCacheTTL = 15 * time.Minute @@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { return nil } - // TODO: Remove hooks selectively - nbnet.RemoveDialerHooks() - nbnet.RemoveListenerHooks() + hooks.RemoveWriteHooks() + hooks.RemoveCloseHooks() + hooks.RemoveAddressRemoveHooks() if err := r.refCounter.Flush(); err != nil { return fmt.Errorf("flush route manager: %w", err) @@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) } func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { - beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { - prefix, err := util.GetPrefixFromIP(ip) - if err != nil { - return fmt.Errorf("convert ip to prefix: %w", err) - } - + beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } @@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } - afterHook := func(connID nbnet.ConnectionID) error { + afterHook := func(connID hooks.ConnectionID) error { if err := r.refCounter.DecrementWithID(string(connID)); err != nil { return fmt.Errorf("remove route reference: %w", err) } @@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M var merr *multierror.Error for _, ip := range initAddresses { - if err := beforeHook("init", ip); err != nil { - merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) + prefix, err := util.GetPrefixFromIP(ip) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err)) + continue + } + if err := beforeHook("init", prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err)) } } - nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { - if ctx.Err() != nil { - return ctx.Err() - } + hooks.AddWriteHook(beforeHook) + hooks.AddCloseHook(afterHook) - var merr *multierror.Error - for _, ip := range resolvedIPs { - merr = multierror.Append(merr, beforeHook(connID, ip.IP)) - } - return nberrors.FormatErrorOrNil(merr) - }) - - nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { - return afterHook(connID) - }) - - nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { - return beforeHook(connID, ip.IP) - }) - - nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { - return afterHook(connID) - }) - - nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error { if _, err := r.refCounter.Decrement(prefix); err != nil { return fmt.Errorf("remove route reference: %w", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index c1c1182bc..32ea38a7a 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" ) type dialer interface { @@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) intf, err := net.InterfaceByName(wgInterface.Name()) @@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - err := r.SetupRouting(nil, nil) + advancedRouting := nbnet.AdvancedRouting() + err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting(nil)) + assert.NoError(t, r.CleanupRouting(nil, advancedRouting)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 10356eae0..99a363371 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -12,14 +12,14 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index c0cef94ba..bd10f131f 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // IPRule contains IP rule information for debugging @@ -94,15 +94,15 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { - if !nbnet.AdvancedRouting() { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) { + if !advancedRouting { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) } defer func() { if err != nil { - if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { - if !nbnet.AdvancedRouting() { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if !advancedRouting { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index f165f7779..d43c2d5bf 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -20,11 +20,11 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix_test.go b/client/internal/routemanager/systemops/systemops_unix_test.go index ad37f611f..959c697e4 100644 --- a/client/internal/routemanager/systemops/systemops_unix_test.go +++ b/client/internal/routemanager/systemops/systemops_unix_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type PacketExpectation struct { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 4f836897b..95645329e 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -8,6 +8,7 @@ import ( "net/netip" "os" "runtime/debug" + "sort" "strconv" "sync" "syscall" @@ -19,9 +20,16 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" ) -const InfiniteLifetime = 0xffffffff +func init() { + nbnet.GetBestInterfaceFunc = GetBestInterface +} + +const ( + InfiniteLifetime = 0xffffffff +) type RouteUpdateType int @@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct { Table [1]MIB_IPFORWARD_ROW2 // Flexible array member } +// candidateRoute represents a potential route for selection during route lookup +type candidateRoute struct { + interfaceIndex uint32 + prefixLength uint8 + routeMetric uint32 + interfaceMetric int +} + // IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix type IP_ADDRESS_PREFIX struct { Prefix SOCKADDR_INET @@ -177,11 +193,20 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return nil + } + return r.cleanupRefCounter(stateManager) } @@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) { func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) { if table != nil { - ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) - if ret != 0 { - log.Warnf("FreeMibTable failed with return code: %d", ret) - } + _, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table))) } } @@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute { entryPtr := basePtr + uintptr(i)*entrySize entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) - detailed := buildWindowsDetailedRoute(entry) - if detailed != nil { + if detailed := buildWindowsDetailedRoute(entry); detailed != nil { detailedRoutes = append(detailedRoutes, *detailed) } } @@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { return ip } +// parseCandidatesFromTable extracts all matching candidate routes from the routing table +func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute { + var candidates []candidateRoute + entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{}) + basePtr := uintptr(unsafe.Pointer(&table.Table[0])) + + for i := uint32(0); i < table.NumEntries; i++ { + entryPtr := basePtr + uintptr(i)*entrySize + entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr)) + + if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil { + candidates = append(candidates, *candidate) + } + } + + return candidates +} + +// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry +// Returns nil if the route doesn't match the destination or should be skipped +func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute { + if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex { + return nil + } + + destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex)) + if !destPrefix.IsValid() || !destPrefix.Contains(dest) { + return nil + } + + interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family) + + return &candidateRoute{ + interfaceIndex: entry.InterfaceIndex, + prefixLength: entry.DestinationPrefix.PrefixLength, + routeMetric: entry.Metric, + interfaceMetric: interfaceMetric, + } +} + // getInterfaceMetric retrieves the interface metric for a given interface and address family func getInterfaceMetric(interfaceIndex uint32, family int16) int { if interfaceIndex == 0 { @@ -821,6 +882,75 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int { return int(ipInterfaceRow.Metric) } +// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric +func sortRouteCandidates(candidates []candidateRoute) { + sort.Slice(candidates, func(i, j int) bool { + if candidates[i].prefixLength != candidates[j].prefixLength { + return candidates[i].prefixLength > candidates[j].prefixLength + } + if candidates[i].routeMetric != candidates[j].routeMetric { + return candidates[i].routeMetric < candidates[j].routeMetric + } + return candidates[i].interfaceMetric < candidates[j].interfaceMetric + }) +} + +// GetBestInterface finds the best interface for reaching a destination, +// excluding the VPN interface to avoid routing loops. +// +// Route selection priority: +// 1. Longest prefix match (most specific route) +// 2. Lowest route metric +// 3. Lowest interface metric +func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { + var skipInterfaceIndex int + if vpnIntf != "" { + if iface, err := net.InterfaceByName(vpnIntf); err == nil { + skipInterfaceIndex = iface.Index + } else { + return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err) + } + } + + table, err := getWindowsRoutingTable() + if err != nil { + return nil, fmt.Errorf("get routing table: %w", err) + } + defer freeWindowsRoutingTable(table) + + candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex) + + if len(candidates) == 0 { + return nil, fmt.Errorf("no route to %s", dest) + } + + // Sort routes: prefix length -> route metric -> interface metric + sortRouteCandidates(candidates) + + for _, candidate := range candidates { + iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex)) + if err != nil { + log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err) + continue + } + + if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() { + continue + } + + if iface.Flags&net.FlagUp == 0 { + log.Debugf("interface %s is down, trying next route", iface.Name) + continue + } + + log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d", + dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric) + return iface, nil + } + + return nil, fmt.Errorf("no usable interface found for %s", dest) +} + // formatRouteAge formats the route age in seconds to a human-readable string func formatRouteAge(ageSeconds uint32) string { if ageSeconds == 0 { diff --git a/client/internal/routemanager/systemops/systemops_windows_test.go b/client/internal/routemanager/systemops/systemops_windows_test.go index 523bd0b0d..3561adec4 100644 --- a/client/internal/routemanager/systemops/systemops_windows_test.go +++ b/client/internal/routemanager/systemops/systemops_windows_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) var ( diff --git a/client/internal/routemanager/util/ip.go b/client/internal/routemanager/util/ip.go index ac5a48e37..57ea32f69 100644 --- a/client/internal/routemanager/util/ip.go +++ b/client/internal/routemanager/util/ip.go @@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) { if !ok { return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip) } + addr = addr.Unmap() - - var prefixLength int - switch { - case addr.Is4(): - prefixLength = 32 - case addr.Is6(): - prefixLength = 128 - default: - return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr) - } - - prefix := netip.PrefixFrom(addr, prefixLength) + prefix := netip.PrefixFrom(addr, addr.BitLen()) return prefix, nil } diff --git a/client/internal/stdnet/dialer.go b/client/internal/stdnet/dialer.go index e80adb42b..8961eaa69 100644 --- a/client/internal/stdnet/dialer.go +++ b/client/internal/stdnet/dialer.go @@ -5,7 +5,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // Dial connects to the address on the named network. diff --git a/client/internal/stdnet/listener.go b/client/internal/stdnet/listener.go index 9ce0a5556..d3be1896f 100644 --- a/client/internal/stdnet/listener.go +++ b/client/internal/stdnet/listener.go @@ -6,7 +6,7 @@ import ( "github.com/pion/transport/v3" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ListenPacket listens for incoming packets on the given network and address. diff --git a/client/net/conn.go b/client/net/conn.go new file mode 100644 index 000000000..918e7f628 --- /dev/null +++ b/client/net/conn.go @@ -0,0 +1,49 @@ +//go:build !ios + +package net + +import ( + "io" + "net" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/net/hooks" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID hooks.ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +func (c *Conn) Close() error { + return closeConn(c.ID, c.Conn) +} + +// TCPConn wraps net.TCPConn to override its Close method to include hook functionality. +type TCPConn struct { + *net.TCPConn + ID hooks.ConnectionID +} + +// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +func (c *TCPConn) Close() error { + return closeConn(c.ID, c.TCPConn) +} + +// closeConn is a helper function to close connections and execute close hooks. +func closeConn(id hooks.ConnectionID, conn io.Closer) error { + err := conn.Close() + + closeHooks := hooks.GetCloseHooks() + for _, hook := range closeHooks { + if err := hook(id); err != nil { + log.Errorf("Error executing close hook: %v", err) + } + } + + return err +} diff --git a/client/net/dial.go b/client/net/dial.go new file mode 100644 index 000000000..041a00e5d --- /dev/null +++ b/client/net/dial.go @@ -0,0 +1,82 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + udpConn, ok := c.Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + switch c := conn.(type) { + case *net.TCPConn: + // Advanced routing: plain connection + return c, nil + case *Conn: + // Legacy routing: wrapped connection preserves close hooks + tcpConn, ok := c.Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn) + } + return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/dial_ios.go b/client/net/dial_ios.go similarity index 100% rename from util/net/dial_ios.go rename to client/net/dial_ios.go diff --git a/util/net/dialer.go b/client/net/dialer.go similarity index 99% rename from util/net/dialer.go rename to client/net/dialer.go index 0786c667e..29bec05a7 100644 --- a/util/net/dialer.go +++ b/client/net/dialer.go @@ -16,6 +16,5 @@ func NewDialer() *Dialer { Dialer: &net.Dialer{}, } dialer.init() - return dialer } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go new file mode 100644 index 000000000..2e1eb53d8 --- /dev/null +++ b/client/net/dialer_dial.go @@ -0,0 +1,87 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + + if CustomRoutingDisabled() || AdvancedRouting() { + return d.Dialer.DialContext(ctx, network, address) + } + + connID := hooks.GenerateConnID() + if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error { + if ctx.Err() != nil { + return ctx.Err() + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + + resolver := customResolver + if resolver == nil { + resolver = net.DefaultResolver + } + + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var merr *multierror.Error + for _, ip := range ips { + prefix, err := util.GetPrefixFromIP(ip.IP) + if err != nil { + merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err)) + continue + } + for _, hook := range writeHooks { + if err := hook(connID, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err)) + } + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/dialer_init_android.go b/client/net/dialer_init_android.go similarity index 100% rename from util/net/dialer_init_android.go rename to client/net/dialer_init_android.go diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go new file mode 100644 index 000000000..18ebc6ad1 --- /dev/null +++ b/client/net/dialer_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (d *Dialer) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/dialer_init_linux.go b/client/net/dialer_init_linux.go similarity index 100% rename from util/net/dialer_init_linux.go rename to client/net/dialer_init_linux.go diff --git a/client/net/dialer_init_windows.go b/client/net/dialer_init_windows.go new file mode 100644 index 000000000..6eefe5b1e --- /dev/null +++ b/client/net/dialer_init_windows.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyUnicastIFToSocket +} diff --git a/util/net/env.go b/client/net/env.go similarity index 94% rename from util/net/env.go rename to client/net/env.go index 32425665d..8f326ca88 100644 --- a/util/net/env.go +++ b/client/net/env.go @@ -11,6 +11,7 @@ import ( const ( envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" ) // CustomRoutingDisabled returns true if custom routing is disabled. diff --git a/client/net/env_android.go b/client/net/env_android.go new file mode 100644 index 000000000..9d89951a1 --- /dev/null +++ b/client/net/env_android.go @@ -0,0 +1,24 @@ +//go:build android + +package net + +// Init initializes the network environment for Android +func Init() { + // No initialization needed on Android +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on Android since we cannot handle routes dynamically. +func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on Android +func SetVPNInterfaceName(name string) { + // No-op on Android - not needed for Android VPN service +} + +// GetVPNInterfaceName returns empty string on Android +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_generic.go b/client/net/env_generic.go new file mode 100644 index 000000000..f467930c3 --- /dev/null +++ b/client/net/env_generic.go @@ -0,0 +1,23 @@ +//go:build !linux && !windows && !android + +package net + +// Init initializes the network environment (no-op on non-Linux/Windows platforms) +func Init() { + // No-op on non-Linux/Windows platforms +} + +// AdvancedRouting returns false on non-Linux/Windows platforms +func AdvancedRouting() bool { + return false +} + +// SetVPNInterfaceName is a no-op on non-Windows platforms +func SetVPNInterfaceName(name string) { + // No-op on non-Windows platforms +} + +// GetVPNInterfaceName returns empty string on non-Windows platforms +func GetVPNInterfaceName() string { + return "" +} diff --git a/util/net/env_linux.go b/client/net/env_linux.go similarity index 86% rename from util/net/env_linux.go rename to client/net/env_linux.go index 3159f6462..82d9a74a8 100644 --- a/util/net/env_linux.go +++ b/client/net/env_linux.go @@ -17,8 +17,7 @@ import ( const ( // these have the same effect, skip socket env supported for backward compatibility - envSkipSocketMark = "NB_SKIP_SOCKET_MARK" - envUseLegacyRouting = "NB_USE_LEGACY_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" ) var advancedRoutingSupported bool @@ -27,6 +26,7 @@ func Init() { advancedRoutingSupported = checkAdvancedRoutingSupport() } +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes func AdvancedRouting() bool { return advancedRoutingSupported } @@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool { } func CheckFwmarkSupport() bool { - // temporarily enable advanced routing to check fwmarks are supported + // temporarily enable advanced routing to check if fwmarks are supported old := advancedRoutingSupported advancedRoutingSupported = true defer func() { @@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool { } return true } + +// SetVPNInterfaceName is a no-op on Linux +func SetVPNInterfaceName(name string) { + // No-op on Linux - not needed for fwmark-based routing +} + +// GetVPNInterfaceName returns empty string on Linux +func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/env_windows.go b/client/net/env_windows.go new file mode 100644 index 000000000..7e8868ba5 --- /dev/null +++ b/client/net/env_windows.go @@ -0,0 +1,67 @@ +//go:build windows + +package net + +import ( + "os" + "strconv" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +var ( + vpnInterfaceName string + vpnInitMutex sync.RWMutex + + advancedRoutingSupported bool +) + +func Init() { + advancedRoutingSupported = checkAdvancedRoutingSupport() +} + +func checkAdvancedRoutingSupport() bool { + var err error + var legacyRouting bool + if val := os.Getenv(envUseLegacyRouting); val != "" { + legacyRouting, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + } + } + + if legacyRouting || netstack.IsEnabled() { + log.Info("advanced routing has been requested to be disabled") + return false + } + + log.Info("system supports advanced routing") + + return true +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes +func AdvancedRouting() bool { + return advancedRoutingSupported +} + +// GetVPNInterfaceName returns the stored VPN interface name +func GetVPNInterfaceName() string { + vpnInitMutex.RLock() + defer vpnInitMutex.RUnlock() + return vpnInterfaceName +} + +// SetVPNInterfaceName sets the VPN interface name for lazy initialization +func SetVPNInterfaceName(name string) { + vpnInitMutex.Lock() + defer vpnInitMutex.Unlock() + vpnInterfaceName = name + + if name != "" { + log.Infof("VPN interface name set to %s for route exclusion", name) + } +} diff --git a/client/net/hooks/hooks.go b/client/net/hooks/hooks.go new file mode 100644 index 000000000..93d8e18ef --- /dev/null +++ b/client/net/hooks/hooks.go @@ -0,0 +1,93 @@ +package hooks + +import ( + "net/netip" + "slices" + "sync" + + "github.com/google/uuid" +) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +} + +type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error +type CloseHookFunc func(connID ConnectionID) error +type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + +var ( + hooksMutex sync.RWMutex + + writeHooks []WriteHookFunc + closeHooks []CloseHookFunc + addressRemoveHooks []AddressRemoveHookFunc +) + +// AddWriteHook allows adding a new hook to be executed before writing/dialing. +func AddWriteHook(hook WriteHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = append(writeHooks, hook) +} + +// AddCloseHook allows adding a new hook to be executed on connection close. +func AddCloseHook(hook CloseHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = append(closeHooks, hook) +} + +// RemoveWriteHooks removes all write hooks. +func RemoveWriteHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + writeHooks = nil +} + +// RemoveCloseHooks removes all close hooks. +func RemoveCloseHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + closeHooks = nil +} + +// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed. +func AddAddressRemoveHook(hook AddressRemoveHookFunc) { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = append(addressRemoveHooks, hook) +} + +// RemoveAddressRemoveHooks removes all listener address hooks. +func RemoveAddressRemoveHooks() { + hooksMutex.Lock() + defer hooksMutex.Unlock() + addressRemoveHooks = nil +} + +// GetWriteHooks returns a copy of the current write hooks. +func GetWriteHooks() []WriteHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(writeHooks) +} + +// GetCloseHooks returns a copy of the current close hooks. +func GetCloseHooks() []CloseHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(closeHooks) +} + +// GetAddressRemoveHooks returns a copy of the current listener address remove hooks. +func GetAddressRemoveHooks() []AddressRemoveHookFunc { + hooksMutex.RLock() + defer hooksMutex.RUnlock() + return slices.Clone(addressRemoveHooks) +} diff --git a/client/net/listen.go b/client/net/listen.go new file mode 100644 index 000000000..da7262806 --- /dev/null +++ b/client/net/listen.go @@ -0,0 +1,47 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + switch c := conn.(type) { + case *net.UDPConn: + // Advanced routing: plain connection + return c, nil + case *PacketConn: + // Legacy routing: wrapped connection for hooks + udpConn, ok := c.PacketConn.(*net.UDPConn) + if !ok { + if err := c.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn) + } + return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil + } + + if err := conn.Close(); err != nil { + log.Errorf("failed to close connection: %v", err) + } + return nil, fmt.Errorf("unexpected connection type: %T", conn) +} diff --git a/util/net/listen_ios.go b/client/net/listen_ios.go similarity index 100% rename from util/net/listen_ios.go rename to client/net/listen_ios.go diff --git a/util/net/listener.go b/client/net/listener.go similarity index 81% rename from util/net/listener.go rename to client/net/listener.go index f4d769f58..4c2f53c05 100644 --- a/util/net/listener.go +++ b/client/net/listener.go @@ -7,14 +7,12 @@ import ( // ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before // responding via the socket and after closing. This can be used to bypass the VPN for listeners. type ListenerConfig struct { - *net.ListenConfig + net.ListenConfig } // NewListener creates a new ListenerConfig instance. func NewListener() *ListenerConfig { - listener := &ListenerConfig{ - ListenConfig: &net.ListenConfig{}, - } + listener := &ListenerConfig{} listener.init() return listener diff --git a/util/net/listener_init_android.go b/client/net/listener_init_android.go similarity index 100% rename from util/net/listener_init_android.go rename to client/net/listener_init_android.go diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go new file mode 100644 index 000000000..4f8f17ab2 --- /dev/null +++ b/client/net/listener_init_generic.go @@ -0,0 +1,7 @@ +//go:build !linux && !windows + +package net + +func (l *ListenerConfig) init() { + // implemented on Linux, Android, and Windows only +} diff --git a/util/net/listener_init_linux.go b/client/net/listener_init_linux.go similarity index 100% rename from util/net/listener_init_linux.go rename to client/net/listener_init_linux.go diff --git a/client/net/listener_init_windows.go b/client/net/listener_init_windows.go new file mode 100644 index 000000000..a9399b5f1 --- /dev/null +++ b/client/net/listener_init_windows.go @@ -0,0 +1,8 @@ +package net + +func (l *ListenerConfig) init() { + // TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses. + // For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case + // the interface will be selected that serves the default route. + l.ListenConfig.Control = applyUnicastIFToSocket +} diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go new file mode 100644 index 000000000..0bb5ad67d --- /dev/null +++ b/client/net/listener_listen.go @@ -0,0 +1,153 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/util" + "github.com/netbirdio/netbird/client/net/hooks" +) + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + if CustomRoutingDisabled() || AdvancedRouting() { + return l.ListenConfig.ListenPacket(ctx, network, address) + } + + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := hooks.GenerateConnID() + + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID hooks.ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil { + log.Errorf("Failed to call write hooks: %v", err) + } + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + defer c.seenAddrs.Clear() + return closeConn(c.ID, c.UDPConn) +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *PacketConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen()) + + addressRemoveHooks := hooks.GetAddressRemoveHooks() + if len(addressRemoveHooks) == 0 { + return + } + + for _, hook := range addressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + +// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality +func WrapPacketConn(conn net.PacketConn) net.PacketConn { + if AdvancedRouting() { + // hooks not required for advanced routing + return conn + } + return &PacketConn{ + PacketConn: conn, + ID: hooks.GenerateConnID(), + seenAddrs: &sync.Map{}, + } +} + +func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error { + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded { + return nil + } + + writeHooks := hooks.GetWriteHooks() + if len(writeHooks) == 0 { + return nil + } + + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr) + } + + prefix, err := util.GetPrefixFromIP(udpAddr.IP) + if err != nil { + return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err) + } + + log.Debugf("Listener resolved IP for %s: %s", addr, prefix) + + var merr *multierror.Error + for _, hook := range writeHooks { + if err := hook(id, prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/util/net/listener_listen_ios.go b/client/net/listener_listen_ios.go similarity index 100% rename from util/net/listener_listen_ios.go rename to client/net/listener_listen_ios.go diff --git a/util/net/net.go b/client/net/net.go similarity index 81% rename from util/net/net.go rename to client/net/net.go index fdcf4ee6a..a97de9d59 100644 --- a/util/net/net.go +++ b/client/net/net.go @@ -5,8 +5,6 @@ import ( "math/big" "net" "net/netip" - - "github.com/google/uuid" ) const ( @@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool { return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper } -// ConnectionID provides a globally unique identifier for network connections. -// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. -type ConnectionID string - -type AddHookFunc func(connID ConnectionID, IP net.IP) error -type RemoveHookFunc func(connID ConnectionID) error - -// GenerateConnID generates a unique identifier for each connection. -func GenerateConnID() ConnectionID { - return ConnectionID(uuid.NewString()) -} - func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) { var endIP net.IP addr := network.Addr().AsSlice() diff --git a/util/net/net_linux.go b/client/net/net_linux.go similarity index 100% rename from util/net/net_linux.go rename to client/net/net_linux.go diff --git a/util/net/net_test.go b/client/net/net_test.go similarity index 100% rename from util/net/net_test.go rename to client/net/net_test.go diff --git a/client/net/net_windows.go b/client/net/net_windows.go new file mode 100644 index 000000000..649d83aaf --- /dev/null +++ b/client/net/net_windows.go @@ -0,0 +1,284 @@ +package net + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "syscall" + "time" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + IpUnicastIf = 31 + Ipv6UnicastIf = 31 + + // https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options + Ipv6V6only = 27 +) + +// GetBestInterfaceFunc is set at runtime to avoid import cycle +var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error) + +// nativeToBigEndian converts a uint32 from native byte order to big-endian +func nativeToBigEndian(v uint32) uint32 { + return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24 +} + +// parseDestinationAddress parses the destination address from various formats +func parseDestinationAddress(network, address string) (netip.Addr, error) { + if address == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + if addrPort, err := netip.ParseAddrPort(address); err == nil { + return addrPort.Addr(), nil + } + + if dest, err := netip.ParseAddr(address); err == nil { + return dest, nil + } + + host, _, err := net.SplitHostPort(address) + if err != nil { + // No port, treat whole string as host + host = address + } + + if host == "" { + if strings.HasSuffix(network, "6") { + return netip.IPv6Unspecified(), nil + } + return netip.IPv4Unspecified(), nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil || len(ips) == 0 { + return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err) + } + + dest, ok := netip.AddrFromSlice(ips[0].IP) + if !ok { + return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP) + } + + if ips[0].Zone != "" { + dest = dest.WithZone(ips[0].Zone) + } + + return dest, nil +} + +func getInterfaceFromZone(zone string) *net.Interface { + if zone == "" { + return nil + } + + idx, err := strconv.Atoi(zone) + if err != nil { + log.Debugf("invalid zone format for Windows (expected numeric): %s", zone) + return nil + } + + iface, err := net.InterfaceByIndex(idx) + if err != nil { + log.Debugf("failed to get interface by index %d from zone: %v", idx, err) + return nil + } + + return iface +} + +type interfaceSelection struct { + iface4 *net.Interface + iface6 *net.Interface +} + +func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection { + iface := getInterfaceFromZone(zone) + if iface == nil { + return nil + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface} + } + return &interfaceSelection{iface4: iface} +} + +func selectInterfaceForUnspecified() (*interfaceSelection, error) { + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + var result interfaceSelection + vpnIfaceName := GetVPNInterfaceName() + + if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil { + result.iface4 = iface4 + } else { + log.Debugf("No IPv4 default route found: %v", err) + } + + if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil { + result.iface6 = iface6 + } else { + log.Debugf("No IPv6 default route found: %v", err) + } + + if result.iface4 == nil && result.iface6 == nil { + return nil, errors.New("no default routes found") + } + + return &result, nil +} + +func selectInterface(dest netip.Addr) (*interfaceSelection, error) { + if zone := dest.Zone(); zone != "" { + if selection := selectInterfaceForZone(dest, zone); selection != nil { + return selection, nil + } + } + + if dest.IsUnspecified() { + return selectInterfaceForUnspecified() + } + + if GetBestInterfaceFunc == nil { + return nil, errors.New("GetBestInterfaceFunc not initialized") + } + + iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName()) + if err != nil { + return nil, fmt.Errorf("find route for %s: %w", dest, err) + } + + if dest.Is6() { + return &interfaceSelection{iface6: iface}, nil + } + return &interfaceSelection{iface4: iface}, nil +} + +func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error { + ifaceIndexBE := nativeToBigEndian(uint32(iface.Index)) + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil { + return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error { + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil { + return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error { + // The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.) + // Never generic ones (udp, tcp, ip) + + switch { + case strings.HasSuffix(network, "4"): + // IPv4-only socket (udp4, tcp4, ip4) + return setUnicastIfIPv4(fd, network, selection, address) + + case strings.HasSuffix(network, "6"): + // IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only + return setUnicastIfIPv6(fd, network, selection, address) + } + + // Shouldn't reach here based on Go's documented behavior + return fmt.Errorf("unexpected network type: %s", network) +} + +func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error { + if selection.iface4 == nil { + return nil + } + + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address) + return nil +} + +func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error { + isDualStack := checkDualStack(fd) + + // For dual-stack sockets, also set the IPv4 option + if isDualStack && selection.iface4 != nil { + if err := setIPv4UnicastIF(fd, selection.iface4); err != nil { + return err + } + log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address) + } + + if selection.iface6 == nil { + return nil + } + + if err := setIPv6UnicastIF(fd, selection.iface6); err != nil { + return err + } + + log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address) + return nil +} + +func checkDualStack(fd uintptr) bool { + var v6Only int + v6OnlyLen := int32(unsafe.Sizeof(v6Only)) + err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen) + return err == nil && v6Only == 0 +} + +// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address +func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + dest, err := parseDestinationAddress(network, address) + if err != nil { + return err + } + + dest = dest.Unmap() + + if !dest.IsValid() { + return fmt.Errorf("invalid destination address for %s", address) + } + + selection, err := selectInterface(dest) + if err != nil { + return err + } + + var controlErr error + err = c.Control(func(fd uintptr) { + controlErr = setUnicastIf(fd, network, selection, address) + }) + + if err != nil { + return fmt.Errorf("control: %w", err) + } + + return controlErr +} diff --git a/util/net/protectsocket_android.go b/client/net/protectsocket_android.go similarity index 100% rename from util/net/protectsocket_android.go rename to client/net/protectsocket_android.go diff --git a/flow/client/client.go b/flow/client/client.go index 949824065..603fd6882 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -20,9 +20,9 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) type GRPCClient struct { diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index dc26253e9..03cc5aec3 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -17,11 +17,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) const ConnectTimeout = 10 * time.Second diff --git a/shared/relay/client/dialer/quic/quic.go b/shared/relay/client/dialer/quic/quic.go index b496f6a9b..967e18d79 100644 --- a/shared/relay/client/dialer/quic/quic.go +++ b/shared/relay/client/dialer/quic/quic.go @@ -12,7 +12,7 @@ import ( log "github.com/sirupsen/logrus" quictls "github.com/netbirdio/netbird/shared/relay/tls" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index 109651f5d..ef6bd6b3c 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/util/embeddedroots" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) type Dialer struct { diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 82ab678f4..48d1ff04f 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -16,10 +16,10 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" - nbgrpc "github.com/netbirdio/netbird/util/grpc" ) // ConnStateNotifier is a wrapper interface of the status recorder diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index d4fedc492..bc2d4d1be 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -22,7 +22,7 @@ import ( "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" - nbnet "github.com/netbirdio/netbird/util/net" + nbnet "github.com/netbirdio/netbird/client/net" ) // ErrSharedSockStopped indicates that shared socket has been stopped @@ -93,7 +93,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error } if err = nbnet.SetSocketMark(rawSock.conn4); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv4 socket: %w", err) } var sockErr error @@ -102,7 +102,7 @@ func Listen(port int, filter BPFFilter, mtu uint16) (_ net.PacketConn, err error log.Errorf("Failed to create ipv6 raw socket: %v", err) } else { if err = nbnet.SetSocketMark(rawSock.conn6); err != nil { - return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err) + return nil, fmt.Errorf("set SO_MARK on ipv6 socket: %w", err) } } diff --git a/util/net/conn.go b/util/net/conn.go deleted file mode 100644 index 26693f841..000000000 --- a/util/net/conn.go +++ /dev/null @@ -1,31 +0,0 @@ -//go:build !ios - -package net - -import ( - "net" - - log "github.com/sirupsen/logrus" -) - -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} diff --git a/util/net/dial.go b/util/net/dial.go deleted file mode 100644 index 595311492..000000000 --- a/util/net/dial.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build !ios - -package net - -import ( - "fmt" - "net" - - log "github.com/sirupsen/logrus" -) - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_dial.go b/util/net/dialer_dial.go deleted file mode 100644 index 1659b6220..000000000 --- a/util/net/dialer_dial.go +++ /dev/null @@ -1,107 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" -) - -type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error -type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error - -var ( - dialerDialHooksMutex sync.RWMutex - dialerDialHooks []DialerDialHookFunc - dialerCloseHooksMutex sync.RWMutex - dialerCloseHooks []DialerCloseHookFunc -) - -// AddDialerHook allows adding a new hook to be executed before dialing. -func AddDialerHook(hook DialerDialHookFunc) { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = append(dialerDialHooks, hook) -} - -// AddDialerCloseHook allows adding a new hook to be executed on connection close. -func AddDialerCloseHook(hook DialerCloseHookFunc) { - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = append(dialerCloseHooks, hook) -} - -// RemoveDialerHooks removes all dialer hooks. -func RemoveDialerHooks() { - dialerDialHooksMutex.Lock() - defer dialerDialHooksMutex.Unlock() - dialerDialHooks = nil - - dialerCloseHooksMutex.Lock() - defer dialerCloseHooksMutex.Unlock() - dialerCloseHooks = nil -} - -// DialContext wraps the net.Dialer's DialContext method to use the custom connection -func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - log.Debugf("Dialing %s %s", network, address) - - if CustomRoutingDisabled() { - return d.Dialer.DialContext(ctx, network, address) - } - - var resolver *net.Resolver - if d.Resolver != nil { - resolver = d.Resolver - } - - connID := GenerateConnID() - if dialerDialHooks != nil { - if err := callDialerHooks(ctx, connID, address, resolver); err != nil { - log.Errorf("Failed to call dialer hooks: %v", err) - } - } - - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) - } - - // Wrap the connection in Conn to handle Close with hooks - return &Conn{Conn: conn, ID: connID}, nil -} - -// Dial wraps the net.Dialer's Dial method to use the custom connection -func (d *Dialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { - host, _, err := net.SplitHostPort(address) - if err != nil { - return fmt.Errorf("split host and port: %w", err) - } - ips, err := resolver.LookupIPAddr(ctx, host) - if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) - } - - log.Debugf("Dialer resolved IPs for %s: %v", address, ips) - - var result *multierror.Error - - dialerDialHooksMutex.RLock() - defer dialerDialHooksMutex.RUnlock() - for _, hook := range dialerDialHooks { - if err := hook(ctx, connID, ips); err != nil { - result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) - } - } - - return result.ErrorOrNil() -} diff --git a/util/net/dialer_init_nonlinux.go b/util/net/dialer_init_nonlinux.go deleted file mode 100644 index 8c57ebbaa..000000000 --- a/util/net/dialer_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (d *Dialer) init() { - // implemented on Linux and Android only -} diff --git a/util/net/env_generic.go b/util/net/env_generic.go deleted file mode 100644 index 6d142a838..000000000 --- a/util/net/env_generic.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !linux || android - -package net - -func Init() { - // nothing to do on non-linux -} - -func AdvancedRouting() bool { - // non-linux currently doesn't support advanced routing - return false -} diff --git a/util/net/listen.go b/util/net/listen.go deleted file mode 100644 index 3ae8a9435..000000000 --- a/util/net/listen.go +++ /dev/null @@ -1,37 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "sync" - - "github.com/pion/transport/v3" - log "github.com/sirupsen/logrus" -) - -// ListenUDP listens on the network address and returns a transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/listener_init_nonlinux.go b/util/net/listener_init_nonlinux.go deleted file mode 100644 index 80f6f7f1a..000000000 --- a/util/net/listener_init_nonlinux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !linux - -package net - -func (l *ListenerConfig) init() { - // implemented on Linux and Android only -} diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go deleted file mode 100644 index 4060ab49a..000000000 --- a/util/net/listener_listen.go +++ /dev/null @@ -1,205 +0,0 @@ -//go:build !ios - -package net - -import ( - "context" - "fmt" - "net" - "net/netip" - "sync" - - log "github.com/sirupsen/logrus" -) - -// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. -type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error - -// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. -type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error - -// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. -type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error - -var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc - listenerAddressRemoveHooksMutex sync.RWMutex - listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc -) - -// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. -func AddListenerWriteHook(hook ListenerWriteHookFunc) { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = append(listenerWriteHooks, hook) -} - -// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. -func AddListenerCloseHook(hook ListenerCloseHookFunc) { - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = append(listenerCloseHooks, hook) -} - -// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. -func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) -} - -// RemoveListenerHooks removes all listener hooks. -func RemoveListenerHooks() { - listenerWriteHooksMutex.Lock() - defer listenerWriteHooksMutex.Unlock() - listenerWriteHooks = nil - - listenerCloseHooksMutex.Lock() - defer listenerCloseHooksMutex.Unlock() - listenerCloseHooks = nil - - listenerAddressRemoveHooksMutex.Lock() - defer listenerAddressRemoveHooksMutex.Unlock() - listenerAddressRemoveHooks = nil -} - -// ListenPacket listens on the network address and returns a PacketConn -// which includes support for write hooks. -func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { - if CustomRoutingDisabled() { - return l.ListenConfig.ListenPacket(ctx, network, address) - } - - pc, err := l.ListenConfig.ListenPacket(ctx, network, address) - if err != nil { - return nil, fmt.Errorf("listen packet: %w", err) - } - connID := GenerateConnID() - - return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil -} - -// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. -type PacketConn struct { - net.PacketConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.PacketConn.WriteTo(b, addr) -} - -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. -func (c *PacketConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.PacketConn) -} - -// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. -type UDPConn struct { - *net.UDPConn - ID ConnectionID - seenAddrs *sync.Map -} - -// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. -func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - callWriteHooks(c.ID, c.seenAddrs, b, addr) - return c.UDPConn.WriteTo(b, addr) -} - -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. -func (c *UDPConn) Close() error { - c.seenAddrs = &sync.Map{} - return closeConn(c.ID, c.UDPConn) -} - -// RemoveAddress removes an address from the seen cache and triggers removal hooks. -func (c *PacketConn) RemoveAddress(addr string) { - if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { - return - } - - ipStr, _, err := net.SplitHostPort(addr) - if err != nil { - log.Errorf("Error splitting IP address and port: %v", err) - return - } - - ipAddr, err := netip.ParseAddr(ipStr) - if err != nil { - log.Errorf("Error parsing IP address %s: %v", ipStr, err) - return - } - - prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) - - listenerAddressRemoveHooksMutex.RLock() - defer listenerAddressRemoveHooksMutex.RUnlock() - - for _, hook := range listenerAddressRemoveHooks { - if err := hook(c.ID, prefix); err != nil { - log.Errorf("Error executing listener address remove hook: %v", err) - } - } -} - - -// WrapPacketConn wraps an existing net.PacketConn with nbnet functionality -func WrapPacketConn(conn net.PacketConn) *PacketConn { - return &PacketConn{ - PacketConn: conn, - ID: GenerateConnID(), - seenAddrs: &sync.Map{}, - } -} - -func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { - // Lookup the address in the seenAddrs map to avoid calling the hooks for every write - if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { - ipStr, _, splitErr := net.SplitHostPort(addr.String()) - if splitErr != nil { - log.Errorf("Error splitting IP address and port: %v", splitErr) - return - } - - ip, err := net.ResolveIPAddr("ip", ipStr) - if err != nil { - log.Errorf("Error resolving IP address: %v", err) - return - } - log.Debugf("Listener resolved IP for %s: %s", addr, ip) - - func() { - listenerWriteHooksMutex.RLock() - defer listenerWriteHooksMutex.RUnlock() - - for _, hook := range listenerWriteHooks { - if err := hook(id, ip, b); err != nil { - log.Errorf("Error executing listener write hook: %v", err) - } - } - }() - } -} - -func closeConn(id ConnectionID, conn net.PacketConn) error { - err := conn.Close() - - listenerCloseHooksMutex.RLock() - defer listenerCloseHooksMutex.RUnlock() - - for _, hook := range listenerCloseHooks { - if err := hook(id, conn); err != nil { - log.Errorf("Error executing listener close hook: %v", err) - } - } - - return err -} From ead1c618ba9a4d850fd1bef1ff7c2ddf444dc653 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 20 Sep 2025 10:00:18 +0200 Subject: [PATCH 15/30] [client] Do not run up cmd if not needed in docker (#4508) optimizes the NetBird client startup process by avoiding unnecessary login commands when the peer is already authenticated. The changes increase the default login timeout and expand the log message patterns used to detect successful authentication. - Increased default login timeout from 1 to 5 seconds for more reliable authentication detection - Enhanced log pattern matching to detect both registration and ready states - Added extended regex support for more flexible pattern matching --- client/Dockerfile | 2 +- client/netbird-entrypoint.sh | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client/Dockerfile b/client/Dockerfile index e19a09909..b2f627409 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -18,7 +18,7 @@ ENV \ NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \ - NB_ENTRYPOINT_LOGIN_TIMEOUT="1" + NB_ENTRYPOINT_LOGIN_TIMEOUT="5" ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] diff --git a/client/netbird-entrypoint.sh b/client/netbird-entrypoint.sh index 2422d2683..7c9fa021a 100755 --- a/client/netbird-entrypoint.sh +++ b/client/netbird-entrypoint.sh @@ -2,7 +2,7 @@ set -eEuo pipefail : ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"} -: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"} +: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"} NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}" export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}" service_pids=() @@ -39,7 +39,7 @@ wait_for_message() { info "not waiting for log line ${message@Q} due to zero timeout." elif test -n "${log_file_path}"; then info "waiting for log line ${message@Q} for ${timeout} seconds..." - grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) + grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null) else info "log file unsupported, sleeping for ${timeout} seconds..." sleep "${timeout}" @@ -81,7 +81,7 @@ wait_for_daemon_startup() { login_if_needed() { local timeout="${1}" - if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then + if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then info "already logged in, skipping 'netbird up'..." else info "logging in..." From e254b4cde55e9dfe565c5b1a1dd18f53014a0a0c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 20 Sep 2025 05:24:04 -0300 Subject: [PATCH 16/30] [misc] Update SIGN_PIPE_VER to version 0.0.23 (#4521) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7be52259b..e9741f541 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.22" + SIGN_PIPE_VER: "v0.0.23" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "NetBird GmbH" From 998fb30e1eb9651f4ad32e9d50f33c0da902dff7 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sat, 20 Sep 2025 22:14:01 +0200 Subject: [PATCH 17/30] [client] Check the client status in the earlier phase (#4509) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. The key improvements focus on: • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops Key Changes Core Status Improvements: - Added waitForReady optional field to StatusRequest proto (daemon.proto:190) - Enhanced status checking logic to detect client state changes earlier in the connection process - Improved handling of client permanent exit scenarios from retry loops Connection & Context Management: - Fixed context cancellation in management and signal client retry mechanisms - Added proper context propagation for Login operations - Enhanced gRPC connection handling with better timeout management Error Handling & Cleanup: - Moved feedback channels to upper layers for better separation of concerns - Improved error handling patterns throughout the client server implementation - Fixed synchronization issues and removed debug logging --- client/cmd/root.go | 2 +- client/cmd/up.go | 4 +- client/embed/embed.go | 2 +- client/grpc/dialer.go | 4 +- client/proto/daemon.pb.go | 22 ++++-- client/proto/daemon.proto | 2 + client/server/server.go | 126 +++++++++++++++++++++++-------- client/server/server_test.go | 16 ++-- shared/management/client/grpc.go | 2 +- shared/signal/client/grpc.go | 2 +- 10 files changed, 128 insertions(+), 54 deletions(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 5084bd38a..11e5228f1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string { // DialClientGRPCServer returns client connection to the daemon server. func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) { - ctx, cancel := context.WithTimeout(ctx, time.Second*3) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() return grpc.DialContext( diff --git a/client/cmd/up.go b/client/cmd/up.go index e686625d6..d047c041e 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager client := proto.NewDaemonServiceClient(conn) - status, err := client.Status(ctx, &proto.StatusRequest{}) + status, err := client.Status(ctx, &proto.StatusRequest{ + WaitForReady: func() *bool { b := true; return &b }(), + }) if err != nil { return fmt.Errorf("unable to get daemon status: %v", err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index de83f9d96..0bfc7a37c 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan struct{}, 1) + run := make(chan struct{}) clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7ac950d85..69e3f088c 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -58,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -72,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) { })) } - connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() conn, err := grpc.DialContext( diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index c633afc83..841e3c0f7 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.6 -// protoc v5.29.3 +// protoc v6.32.1 // source: daemon.proto package proto @@ -794,8 +794,10 @@ type StatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"` ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *StatusRequest) Reset() { @@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool { return false } +func (x *StatusRequest) GetWaitForReady() bool { + if x != nil && x.WaitForReady != nil { + return *x.WaitForReady + } + return false +} + type StatusResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // status of the server. @@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\f\n" + "\n" + - "UpResponse\"g\n" + + "UpResponse\"\xa1\x01\n" + "\rStatusRequest\x12,\n" + "\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" + - "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" + + "\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" + + "\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" + + "\r_waitForReady\"\x82\x01\n" + "\x0eStatusResponse\x12\x16\n" + "\x06status\x18\x01 \x01(\tR\x06status\x122\n" + "\n" + @@ -5231,6 +5242,7 @@ func file_daemon_proto_init() { } file_daemon_proto_msgTypes[1].OneofWrappers = []any{} file_daemon_proto_msgTypes[5].OneofWrappers = []any{} + file_daemon_proto_msgTypes[7].OneofWrappers = []any{} file_daemon_proto_msgTypes[26].OneofWrappers = []any{ (*PortInfo_Port)(nil), (*PortInfo_Range_)(nil), diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 0cd3579b9..5b27b4d98 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -186,6 +186,8 @@ message UpResponse {} message StatusRequest{ bool getFullPeerStatus = 1; bool shouldRunProbes = 2; + // the UI do not using this yet, but CLIs could use it to wait until the status is ready + optional bool waitForReady = 3; } message StatusResponse{ diff --git a/client/server/server.go b/client/server/server.go index fae342f78..e6de608c5 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -67,6 +67,7 @@ type Server struct { proto.UnimplementedDaemonServiceServer clientRunning bool // protected by mutex clientRunningChan chan struct{} + clientGiveUpChan chan struct{} connectClient *internal.ConnectClient @@ -106,6 +107,10 @@ func (s *Server) Start() error { s.mutex.Lock() defer s.mutex.Unlock() + if s.clientRunning { + return nil + } + state := internal.CtxGetState(s.rootCtx) if err := handlePanicLog(); err != nil { @@ -175,12 +180,10 @@ func (s *Server) Start() error { return nil } - if s.clientRunning { - return nil - } s.clientRunning = true - s.clientRunningChan = make(chan struct{}, 1) - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan) + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) return nil } @@ -211,7 +214,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error { // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. -func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) { +func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) { defer func() { s.mutex.Lock() s.clientRunning = false @@ -261,6 +264,10 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil if err := backoff.Retry(runOperation, backOff); err != nil { log.Errorf("operation failed: %v", err) } + + if giveUpChan != nil { + close(giveUpChan) + } } // loginAttempt attempts to login using the provided information. it returns a status in case something fails @@ -379,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro if s.actCancel != nil { s.actCancel() } - ctx, cancel := context.WithCancel(s.rootCtx) + ctx, cancel := context.WithCancel(callerCtx) md, ok := metadata.FromIncomingContext(callerCtx) if ok { @@ -389,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() - if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil { + if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil { log.Warnf(errRestoreResidualState, err) } - state := internal.CtxGetState(ctx) + state := internal.CtxGetState(s.rootCtx) defer func() { status, err := state.Status() if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) { @@ -606,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin // Up starts engine work in the daemon. func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) { s.mutex.Lock() + if s.clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + s.mutex.Unlock() + return nil, err + } + if status == internal.StatusNeedsLogin { + s.actCancel() + } + s.mutex.Unlock() + + return s.waitForUp(callerCtx) + } defer s.mutex.Unlock() if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil { @@ -621,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR if err != nil { return nil, err } + if status != internal.StatusIdle { return nil, fmt.Errorf("up already in progress: current status %s", status) } - // it should be nil here, but . + // it should be nil here, but in case it isn't we cancel it. if s.actCancel != nil { s.actCancel() } ctx, cancel := context.WithCancel(s.rootCtx) - md, ok := metadata.FromIncomingContext(callerCtx) if ok { ctx = metadata.NewOutgoingContext(ctx, md) @@ -673,26 +694,31 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) + s.clientRunning = true + s.clientRunningChan = make(chan struct{}) + s.clientGiveUpChan = make(chan struct{}) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan) + + return s.waitForUp(callerCtx) +} + +// todo: handle potential race conditions +func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) { timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) defer cancel() - if !s.clientRunning { - s.clientRunning = true - s.clientRunningChan = make(chan struct{}, 1) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan) - } - for { - select { - case <-s.clientRunningChan: - s.isSessionActive.Store(true) - return &proto.UpResponse{}, nil - case <-callerCtx.Done(): - log.Debug("context done, stopping the wait for engine to become ready") - return nil, callerCtx.Err() - case <-timeoutCtx.Done(): - log.Debug("up is timed out, stopping the wait for engine to become ready") - return nil, timeoutCtx.Err() - } + select { + case <-s.clientGiveUpChan: + return nil, fmt.Errorf("client gave up to connect") + case <-s.clientRunningChan: + s.isSessionActive.Store(true) + return &proto.UpResponse{}, nil + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() } } @@ -966,12 +992,46 @@ func (s *Server) Status( ctx context.Context, msg *proto.StatusRequest, ) (*proto.StatusResponse, error) { - if ctx.Err() != nil { - return nil, ctx.Err() - } - s.mutex.Lock() - defer s.mutex.Unlock() + clientRunning := s.clientRunning + s.mutex.Unlock() + + if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning { + state := internal.CtxGetState(s.rootCtx) + status, err := state.Status() + if err != nil { + return nil, err + } + + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + loop: + for { + select { + case <-s.clientGiveUpChan: + ticker.Stop() + break loop + case <-s.clientRunningChan: + ticker.Stop() + break loop + case <-ticker.C: + status, err := state.Status() + if err != nil { + continue + } + if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting { + s.actCancel() + } + continue + case <-ctx.Done(): + return nil, ctx.Err() + } + } + } status, err := internal.CtxGetState(s.rootCtx).Status() if err != nil { diff --git a/client/server/server_test.go b/client/server/server_test.go index 45a1aa5c7..755925003 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) { profName := "default" + u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") + require.NoError(t, err) + ic := profilemanager.ConfigInput{ - ConfigPath: filepath.Join(tempDir, profName+".json"), + ConfigPath: filepath.Join(tempDir, profName+".json"), + ManagementURL: u.String(), } _, err = profilemanager.UpdateOrCreateConfig(ic) @@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) { } s := New(ctx, "console", "", false, false) - err = s.Start() require.NoError(t, err) - u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345") - require.NoError(t, err) - s.config = &profilemanager.Config{ - ManagementURL: u, - } - upCtx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) { Username: &currUser.Username, } _, err = s.Up(upCtx, upReq) + log.Errorf("error from Up: %v", err) assert.Contains(t, err.Error(), "context deadline exceeded") } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 03cc5aec3..f30e965be 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 48d1ff04f..5ca0c0282 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) if err != nil { log.Printf("createConnection error: %v", err) return err From 5853b5553c9b9e80322add7ae5076ba5cea2b74c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 22 Sep 2025 14:32:00 +0200 Subject: [PATCH 18/30] [client] Skip interface for route lookup if it doesn't exist (#4524) --- client/internal/routemanager/systemops/systemops_windows.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 95645329e..7bce6af80 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -908,7 +908,8 @@ func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) { if iface, err := net.InterfaceByName(vpnIntf); err == nil { skipInterfaceIndex = iface.Index } else { - return nil, fmt.Errorf("get VPN interface %s: %w", vpnIntf, err) + // not critical, if we cannot get ahold of the interface then we won't need to skip it + log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err) } } From 58faa341d2d6dd9a3821ca9a32d7035b4f933381 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 23 Sep 2025 12:06:10 +0200 Subject: [PATCH 19/30] [management] Add logs for update channel (#4527) --- management/server/grpcserver.go | 1 + management/server/user.go | 1 + 2 files changed, 2 insertions(+) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 27d54e6c2..60a00207e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -258,6 +258,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } diff --git a/management/server/user.go b/management/server/user.go index 3c7c3f433..d40d33c6a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,6 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service + log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.BufferUpdateAccountPeers(ctx, accountID) } From 644ed4b934452d578ec4854ebda311cd42341f08 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:36:26 +0700 Subject: [PATCH 20/30] [client] Add WireGuard interface lifecycle monitoring (#4370) * [client] Add WireGuard interface lifecycle monitoring --- client/internal/engine.go | 23 +++++++ client/internal/wg_iface_monitor.go | 98 +++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 client/internal/wg_iface_monitor.go diff --git a/client/internal/engine.go b/client/internal/engine.go index d4c465efb..828bc6e94 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -198,6 +198,10 @@ type Engine struct { latestSyncResponse *mgmProto.SyncResponse connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + + // WireGuard interface monitor + wgIfaceMonitor *WGIfaceMonitor + wgIfaceMonitorWg sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -341,6 +345,9 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } + // Stop WireGuard interface monitor and wait for it to exit + e.wgIfaceMonitorWg.Wait() + return nil } @@ -479,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // starting network monitor at the very last to avoid disruptions e.startNetworkMonitor() + + // monitor WireGuard interface lifecycle and restart engine on changes + e.wgIfaceMonitor = NewWGIfaceMonitor() + e.wgIfaceMonitorWg.Add(1) + + go func() { + defer e.wgIfaceMonitorWg.Done() + + if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { + log.Infof("WireGuard interface monitor: %s, restarting engine", err) + e.restartEngine() + } else if err != nil { + log.Warnf("WireGuard interface monitor: %s", err) + } + }() + return nil } diff --git a/client/internal/wg_iface_monitor.go b/client/internal/wg_iface_monitor.go new file mode 100644 index 000000000..78d70c15b --- /dev/null +++ b/client/internal/wg_iface_monitor.go @@ -0,0 +1,98 @@ +package internal + +import ( + "context" + "errors" + "fmt" + "net" + "runtime" + "time" + + log "github.com/sirupsen/logrus" +) + +// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine +// if the interface is deleted externally while the engine is running. +type WGIfaceMonitor struct { + done chan struct{} +} + +// NewWGIfaceMonitor creates a new WGIfaceMonitor instance. +func NewWGIfaceMonitor() *WGIfaceMonitor { + return &WGIfaceMonitor{ + done: make(chan struct{}), + } +} + +// Start begins monitoring the WireGuard interface. +// It relies on the provided context cancellation to stop. +func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { + defer close(m.done) + + // Skip on mobile platforms as they handle interface lifecycle differently + if runtime.GOOS == "android" || runtime.GOOS == "ios" { + log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS) + return false, errors.New("not supported on mobile platforms") + } + + if ifaceName == "" { + log.Debugf("Interface monitor: empty interface name, skipping monitor") + return false, errors.New("empty interface name") + } + + // Get initial interface index to track the specific interface instance + expectedIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName) + return false, fmt.Errorf("interface %s not found: %w", ifaceName, err) + } + + log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Infof("Interface monitor: stopped for %s", ifaceName) + return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err()) + case <-ticker.C: + currentIndex, err := getInterfaceIndex(ifaceName) + if err != nil { + // Interface was deleted + log.Infof("Interface monitor: %s deleted", ifaceName) + return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err) + } + + // Check if interface index changed (interface was recreated) + if currentIndex != expectedIndex { + log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine", + ifaceName, expectedIndex, currentIndex) + return true, nil + } + } + } + +} + +// getInterfaceIndex returns the index of a network interface by name. +// Returns an error if the interface is not found. +func getInterfaceIndex(name string) (int, error) { + if name == "" { + return 0, fmt.Errorf("empty interface name") + } + ifi, err := net.InterfaceByName(name) + if err != nil { + // Check if it's specifically a "not found" error + if errors.Is(err, &net.OpError{}) { + // On some systems, this might be a "not found" error + return 0, fmt.Errorf("interface not found: %w", err) + } + return 0, fmt.Errorf("failed to lookup interface: %w", err) + } + if ifi == nil { + return 0, fmt.Errorf("interface not found") + } + return ifi.Index, nil +} From 25ed58328aa418e88cdc78b6961c9b0b015557e3 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:29:14 +0200 Subject: [PATCH 21/30] [management] fix network map dns filter (#4547) --- management/server/dns.go | 30 ++---------------------------- management/server/dns_test.go | 9 --------- management/server/types/account.go | 1 - 3 files changed, 2 insertions(+), 38 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index f6f0201d3..6b73dbd0e 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -20,29 +20,9 @@ import ( // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { - CustomZones sync.Map NameServerGroups sync.Map } -// GetCustomZone retrieves a cached custom zone -func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) { - if c == nil { - return nil, false - } - if value, ok := c.CustomZones.Load(key); ok { - return value.(*proto.CustomZone), true - } - return nil, false -} - -// SetCustomZone stores a custom zone in the cache -func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) { - if c == nil { - return - } - c.CustomZones.Store(key, value) -} - // GetNameServerGroup retrieves a cached name server group func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { if c == nil { @@ -212,14 +192,8 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC } for _, zone := range update.CustomZones { - cacheKey := zone.Domain - if cachedZone, exists := cache.GetCustomZone(cacheKey); exists { - protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone) - } else { - protoZone := convertToProtoCustomZone(zone) - cache.SetCustomZone(cacheKey, protoZone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) } for _, nsGroup := range update.NameServerGroups { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index d58689544..55a1bbe66 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -474,15 +474,6 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Results should be different for different inputs") } - // Verify that the cache contains elements from both configs - if _, exists := cache.GetCustomZone("example.com"); !exists { - t.Errorf("Cache should contain custom zone for example.com") - } - - if _, exists := cache.GetCustomZone("example.org"); !exists { - t.Errorf("Cache should contain custom zone for example.org") - } - if _, exists := cache.GetNameServerGroup("group1"); !exists { t.Errorf("Cache should contain name server group 'group1'") } diff --git a/management/server/types/account.go b/management/server/types/account.go index ca075b9f6..a69d3bb08 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -300,7 +300,6 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone - if peersCustomZone.Domain != "" { records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) zones = append(zones, nbdns.CustomZone{ From 17bab881f78bc49efe68354a1e16e8870c05b530 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Fri, 26 Sep 2025 16:42:18 +0700 Subject: [PATCH 22/30] [client] Add Windows DNS Policies To GPO Path Always (#4460) [client] Add Windows DNS Policies To GPO Path Always (#4460) --- client/internal/dns/host_windows.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index fdc2c3063..0d3f033fb 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { - policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - if r.gpo { - policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - } singleDomain := []string{domain} - if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err) + if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + } + + if r.gpo { + if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + return i, fmt.Errorf("configure gpo DNS policy: %w", err) + } } log.Debugf("added NRPT entry for domain: %s", domain) @@ -401,6 +403,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err)) } @@ -412,6 +415,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error { if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err)) } + if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err)) } From e8d301fdc9357b9ca9ecf9ce40d4d347255d3bf9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 30 Sep 2025 15:31:18 +0200 Subject: [PATCH 23/30] [client] Fix/pkg loss (#3338) The Relayed connection setup is optimistic. It does not have any confirmation of an established end-to-end connection. Peers start sending WireGuard handshake packets immediately after the successful offer-answer handshake. Meanwhile, for successful P2P connection negotiation, we change the WireGuard endpoint address, but this change does not trigger new handshake initiation. Because the peer switched from Relayed connection to P2P, the packets from the Relay server are dropped and must wait for the next WireGuard handshake via P2P. To avoid this scenario, the relayed WireGuard proxy no longer drops the packets. Instead, it rewrites the source address to the new P2P endpoint and continues forwarding the packets. We still have one corner case: if the Relayed server negotiation chooses a server that has not been used before. In this case, one side of the peer connection will be slower to reach the Relay server, and the Relay server will drop the handshake packet. If everything goes well we should see exactly 5 seconds improvements between the WireGuard configuration time and the handshake time. --- client/iface/bind/endpoint.go | 14 +- client/iface/bind/ice_bind.go | 15 +- client/iface/iface_new_freebsd.go | 41 +++++ .../{iface_new_unix.go => iface_new_linux.go} | 2 +- client/iface/wgproxy/bind/proxy.go | 107 ++++++++---- client/iface/wgproxy/ebpf/proxy.go | 59 ++----- client/iface/wgproxy/ebpf/wrapper.go | 79 +++++---- client/iface/wgproxy/factory_kernel.go | 1 - .../iface/wgproxy/factory_kernel_freebsd.go | 31 ---- client/iface/wgproxy/proxy.go | 5 + client/iface/wgproxy/proxy_linux_test.go | 104 +++++++----- client/iface/wgproxy/proxy_seed_test.go | 39 +++++ client/iface/wgproxy/proxy_test.go | 152 ++++++++++++++---- client/iface/wgproxy/rawsocket/rawsocket.go | 50 ++++++ client/iface/wgproxy/udp/proxy.go | 94 ++++++++--- client/iface/wgproxy/udp/rawsocket.go | 101 ++++++++++++ client/internal/peer/conn.go | 62 ++++--- client/internal/peer/endpoint.go | 105 ++++++++++++ 18 files changed, 784 insertions(+), 277 deletions(-) create mode 100644 client/iface/iface_new_freebsd.go rename client/iface/{iface_new_unix.go => iface_new_linux.go} (97%) delete mode 100644 client/iface/wgproxy/factory_kernel_freebsd.go create mode 100644 client/iface/wgproxy/proxy_seed_test.go create mode 100644 client/iface/wgproxy/rawsocket/rawsocket.go create mode 100644 client/iface/wgproxy/udp/rawsocket.go create mode 100644 client/internal/peer/endpoint.go diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go index 1926ff88f..caa92f05d 100644 --- a/client/iface/bind/endpoint.go +++ b/client/iface/bind/endpoint.go @@ -1,5 +1,17 @@ package bind -import wgConn "golang.zx2c4.com/wireguard/conn" +import ( + "net" + + wgConn "golang.zx2c4.com/wireguard/conn" +) type Endpoint = wgConn.StdNetEndpoint + +func EndpointToUDPAddr(e Endpoint) *net.UDPAddr { + return &net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + } +} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 577c7c0c4..ef630b9d0 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,7 @@ package bind import ( + "context" "encoding/binary" "fmt" "net" @@ -42,7 +43,7 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - RecvChan chan RecvMessage + recvChan chan RecvMessage transportNet transport.Net filterFn udpmux.FilterFn @@ -65,7 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - RecvChan: make(chan RecvMessage, 1), + recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, endpoints: make(map[netip.Addr]net.Conn), @@ -155,6 +156,14 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } +func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { + select { + case <-ctx.Done(): + return + case b.recvChan <- msg: + } +} + func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() @@ -271,7 +280,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo select { case <-c.closedChan: return 0, net.ErrClosed - case msg, ok := <-c.RecvChan: + case msg, ok := <-c.recvChan: if !ok { return 0, net.ErrClosed } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go new file mode 100644 index 000000000..86ed14ce1 --- /dev/null +++ b/client/iface/iface_new_freebsd.go @@ -0,0 +1,41 @@ +//go:build freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or load tun module") +} diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_linux.go similarity index 97% rename from client/iface/iface_new_unix.go rename to client/iface/iface_new_linux.go index 493144f13..77fd30fae 100644 --- a/client/iface/iface_new_unix.go +++ b/client/iface/iface_new_linux.go @@ -1,4 +1,4 @@ -//go:build (linux && !android) || freebsd +//go:build linux && !android package iface diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index bf6da72c2..dbc694e91 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,28 +16,37 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +type IceBind interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) + Recv(ctx context.Context, msg bind.RecvMessage) + MTU() uint16 +} + type ProxyBind struct { - Bind *bind.ICEBind + bind IceBind - fakeNetIP *netip.AddrPort - wgBindEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address + wgRelayedEndpoint *bind.Endpoint + wgCurrentUsed *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyBind(bind *bind.ICEBind) *ProxyBind { +func NewProxyBind(bind IceBind) *ProxyBind { p := &ProxyBind{ - Bind: bind, + bind: bind, closeListener: listener.NewCloseListener(), + pausedCond: sync.NewCond(&sync.Mutex{}), } return p @@ -46,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind { // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. +// +// Parameters: +// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages +// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address +// - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - - p.fakeNetIP = fakeNetIP - p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} + p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } + func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return &net.UDPAddr{ - IP: p.fakeNetIP.Addr().AsSlice(), - Port: int(p.fakeNetIP.Port()), - Zone: p.fakeNetIP.Addr().Zone(), - } + return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { @@ -76,17 +85,21 @@ func (p *ProxyBind) Work() { return } - p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) + p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { @@ -94,9 +107,25 @@ func (p *ProxyBind) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgCurrentUsed = addrToEndpoint(endpoint) + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} } func (p *ProxyBind) CloseConn() error { @@ -107,6 +136,10 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { + if p.remoteConn == nil { + return nil + } + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -120,7 +153,12 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -136,7 +174,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -147,18 +185,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } msg := bind.RecvMessage{ - Endpoint: p.wgBindEndpoint, + Endpoint: p.wgCurrentUsed, Buffer: buf[:n], } - p.Bind.RecvChan <- msg - p.pausedMu.Unlock() + p.bind.Recv(ctx, msg) + p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index b899f1694..858143091 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -6,9 +6,7 @@ import ( "context" "fmt" "net" - "os" "sync" - "syscall" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -18,6 +16,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bufsize" + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" nbnet "github.com/netbirdio/netbird/client/net" @@ -27,6 +26,10 @@ const ( loopbackAddr = "127.0.0.1" ) +var ( + localHostNetIP = net.ParseIP("127.0.0.1") +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = p.prepareSenderRawSocket() + p.rawConn, err = rawsocket.PrepareSenderRawSocket() if err != nil { return err } @@ -214,57 +217,17 @@ generatePort: return p.lastUsedPort, nil } -func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { - // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) - if err != nil { - return nil, fmt.Errorf("creating raw socket failed: %w", err) - } - - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) - } - - // Bind the socket to the "lo" interface. - err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") - if err != nil { - return nil, fmt.Errorf("binding to lo interface failed: %w", err) - } - - // Set the fwmark on the socket. - err = nbnet.SetSocketOpt(fd) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - - // Convert the file descriptor to a PacketConn. - file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) - if file == nil { - return nil, fmt.Errorf("converting fd to file failed") - } - packetConn, err := net.FilePacketConn(file) - if err != nil { - return nil, fmt.Errorf("converting file to packet conn failed: %w", err) - } - - return packetConn, nil -} - -func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { - localhost := net.ParseIP("127.0.0.1") - +func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { payload := gopacket.Payload(data) ipH := &layers.IPv4{ - DstIP: localhost, - SrcIP: localhost, + DstIP: localHostNetIP, + SrcIP: endpointAddr.IP, Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, } udpH := &layers.UDP{ - SrcPort: layers.UDPPort(port), + SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } @@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { if err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 3d71b01bd..ff44d30c0 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -18,41 +18,42 @@ import ( // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { - WgeBPFProxy *WGEBPFProxy + wgeBPFProxy *WGEBPFProxy remoteConn net.Conn ctx context.Context cancel context.CancelFunc - wgEndpointAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + wgEndpointCurrentUsedAddr *net.UDPAddr - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { +func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { return &ProxyWrapper{ - WgeBPFProxy: WgeBPFProxy, + wgeBPFProxy: proxy, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } } - func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) + addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - p.wgEndpointAddr = addr + p.wgRelayedEndpointAddr = addr return err } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { - return p.wgEndpointAddr + return p.wgRelayedEndpointAddr } func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { @@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyWrapper) Pause() { @@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() { } log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgEndpointCurrentUsedAddr = endpoint + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map -func (e *ProxyWrapper) CloseConn() error { - if e.cancel == nil { +func (p *ProxyWrapper) CloseConn() error { + if p.cancel == nil { return fmt.Errorf("proxy not started") } - e.cancel() + p.cancel() - e.closeListener.SetCloseListener(nil) + p.closeListener.SetCloseListener(nil) - if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("close remote conn: %w", err) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to close remote conn: %w", err) } return nil } func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { - defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) - buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) for { n, err := p.readFromRemote(ctx, buf) if err != nil { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) - p.pausedMu.Unlock() + err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { @@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } p.closeListener.Notify() if !errors.Is(err, io.EOF) { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) } return 0, err } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 63bc2ed24..ad2807546 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy { } return ebpf.NewProxyWrapper(w.ebpfProxy) - } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go deleted file mode 100644 index 039f1cd3a..000000000 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ /dev/null @@ -1,31 +0,0 @@ -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" -) - -// KernelFactory todo: check eBPF support on FreeBSD -type KernelFactory struct { - wgPort int - mtu uint16 -} - -func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { - log.Infof("WireGuard Proxy Factory will produce UDP proxy") - f := &KernelFactory{ - wgPort: wgPort, - mtu: mtu, - } - - return f -} - -func (w *KernelFactory) GetProxy() Proxy { - return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu) -} - -func (w *KernelFactory) Free() error { - return nil -} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index c2879877e..3c8dfd30e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -11,6 +11,11 @@ type Proxy interface { EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + + //RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused + //and rewrite the src address to the endpoint address. + //With this logic can avoid the package loss from relayed connections. + RedirectAs(endpoint *net.UDPAddr) CloseConn() error SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 5add503e1..9526e91d2 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -3,54 +3,82 @@ package wgproxy import ( - "context" - "os" - "testing" + "fmt" + "net" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) -func TestProxyCloseByRemoteConnEBPF(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") != "true" { - t.Skip("Skipping test as it requires root privileges") - } - ctx := context.Background() +func seedProxies() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - - tests := []struct { - name string - proxy Proxy - }{ - { - name: "ebpf proxy", - proxy: &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - }, - }, + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, } + pl = append(pl, pEbpf) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) - if err != nil { - t.Errorf("error: %v", err) - } - - _ = relayedConn.Close() - if err := tt.proxy.CloseConn(); err != nil { - t.Errorf("error: %v", err) - } - }) + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, } + pl = append(pl, pUDP) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + + ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) + if err := ebpfProxy.Listen(); err != nil { + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) + } + + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, + } + pl = append(pl, pEbpf) + + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, + } + pl = append(pl, pUDP) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + + return pl, nil } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go new file mode 100644 index 000000000..4d244f18a --- /dev/null +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -0,0 +1,39 @@ +//go:build !linux + +package wgproxy + +import ( + "net" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +func seedProxies() ([]proxyInstance, error) { + // todo extend with Bind proxy + pl := make([]proxyInstance, 0) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + return pl, nil +} diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 76e5ed6f7..1aeab66b7 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -1,5 +1,3 @@ -//go:build linux - package wgproxy import ( @@ -7,12 +5,9 @@ import ( "io" "net" "os" - "runtime" "testing" "time" - "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -22,6 +17,14 @@ func TestMain(m *testing.M) { os.Exit(code) } +type proxyInstance struct { + name string + proxy Proxy + wgPort int + endpointAddr *net.UDPAddr + closeFn func() error +} + type mocConn struct { closeChan chan struct{} closed bool @@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error { func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() - tests := []struct { - name string - proxy Proxy - }{ - { - name: "userspace proxy", - proxy: udpProxy.NewWGUDPProxy(51830, 1280), - }, + tests, err := seedProxyForProxyCloseByRemoteConn() + if err != nil { + t.Fatalf("error: %v", err) } - if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { - ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) - if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) - } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) - - tests = append(tests, struct { - name string - proxy Proxy - }{ - name: "ebpf proxy", - proxy: proxyWrapper, - }) - } + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + defer func() { + _ = relayedConn.Close() + }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892") relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) if err != nil { t.Errorf("error: %v", err) } @@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }) } } + +// TestProxyRedirect todo extend the proxies with Bind proxy +func TestProxyRedirect(t *testing.T) { + tests, err := seedProxies() + if err != nil { + t.Fatalf("error: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) + if err := tt.closeFn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} + +func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { + t.Helper() + + msgHelloFromRelay := []byte("hello from relay") + msgRedirected := [][]byte{ + []byte("hello 1. to p2p"), + []byte("hello 2. to p2p"), + []byte("hello 3. to p2p"), + } + + dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: wgPort}) + if err != nil { + t.Fatalf("failed to listen on udp port: %s", err) + } + + relayedServer, _ := net.ListenUDP("udp", + &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, + ) + + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + + defer func() { + _ = dummyWgListener.Close() + _ = relayedConn.Close() + _ = relayedServer.Close() + }() + + if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { + t.Errorf("error: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }() + + proxy.Work() + + if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil { + t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err) + } + + n, err := dummyWgListener.Read(make([]byte, 1024)) + if err != nil { + t.Errorf("error: %v", err) + } + + if n != len(msgHelloFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n) + } + + p2pEndpointAddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 56), + Port: 1234, + } + proxy.RedirectAs(p2pEndpointAddr) + + for _, msg := range msgRedirected { + if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil { + t.Errorf("error: %v", err) + } + } + + for i := 0; i < len(msgRedirected); i++ { + buf := make([]byte, 1024) + n, rAddr, err := dummyWgListener.ReadFrom(buf) + if err != nil { + t.Errorf("error: %v", err) + } + + if rAddr.String() != p2pEndpointAddr.String() { + t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String()) + } + if string(buf[:n]) != string(msgRedirected[i]) { + t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n])) + } + } +} diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go new file mode 100644 index 000000000..a11ac46d5 --- /dev/null +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -0,0 +1,50 @@ +//go:build linux && !android + +package rawsocket + +import ( + "fmt" + "net" + "os" + "syscall" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func PrepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return nil, fmt.Errorf("creating raw socket failed: %w", err) + } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) + if err != nil { + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + + // Bind the socket to the "lo" interface. + err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") + if err != nil { + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = nbnet.SetSocketOpt(fd) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) + } + + return packetConn, nil +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index be65e2b27..4ef2f19c4 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package udp import ( @@ -21,16 +23,18 @@ type WGUDPProxy struct { localWGListenPort int mtu uint16 - remoteConn net.Conn - localConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + remoteConn net.Conn + localConn net.Conn + srcFakerConn *SrcFaker + sendPkg func(data []byte) (int, error) + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } @@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { p := &WGUDPProxy{ localWGListenPort: wgPort, mtu: mtu, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } return p @@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem p.ctx, p.cancel = context.WithCancel(ctx) p.localConn = localConn + p.sendPkg = p.localConn.Write p.remoteConn = remoteConn return err @@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + p.sendPkg = p.localConn.Write + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } if !p.isStarted { p.isStarted = true go p.proxyToRemote(p.ctx) go p.proxyToLocal(p.ctx) } + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // Pause pauses the proxy from receiving data from the remote peer @@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +// RedirectAs start to use the fake sourced raw socket as package sender +func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + defer func() { + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + }() + + p.paused = false + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } + srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create src faker conn: %s", err) + // fallback to continue without redirecting + p.paused = true + return + } + p.srcFakerConn = srcFakerConn + p.sendPkg = p.srcFakerConn.SendPkg } // CloseConn close the localConn @@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error { } func (p *WGUDPProxy) close() error { + var result *multierror.Error + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error { p.cancel() - var result *multierror.Error + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } @@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err)) + } + } + return cerrors.FormatErrorOrNil(result) } @@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - - _, err = p.localConn.Write(buf[:n]) - p.pausedMu.Unlock() + _, err = p.sendPkg(buf[:n]) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go new file mode 100644 index 000000000..fdc911463 --- /dev/null +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -0,0 +1,101 @@ +//go:build linux && !android + +package udp + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" +) + +var ( + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + localHostNetIPAddr = &net.IPAddr{ + IP: net.ParseIP("127.0.0.1"), + } +) + +type SrcFaker struct { + srcAddr *net.UDPAddr + + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer +} + +func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { + rawSocket, err := rawsocket.PrepareSenderRawSocket() + if err != nil { + return nil, err + } + + ipH, udpH, err := prepareHeaders(dstPort, srcAddr) + if err != nil { + return nil, err + } + + f := &SrcFaker{ + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + } + + return f, nil +} + +func (f *SrcFaker) Close() error { + return f.rawSocket.Close() +} + +func (f *SrcFaker) SendPkg(data []byte) (int, error) { + defer func() { + if err := f.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload) + if err != nil { + return 0, fmt.Errorf("serialize layers: %w", err) + } + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + if err != nil { + return 0, fmt.Errorf("write to raw conn: %w", err) + } + return n, nil +} + +func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { + ipH := &layers.IPv4{ + DstIP: net.ParseIP("127.0.0.1"), + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(srcAddr.Port), + DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port + } + + err := udpH.SetNetworkLayerForChecksum(ipH) + if err != nil { + return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return ipH, udpH, nil +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 86e4596d4..8db9e58f4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -28,10 +28,6 @@ import ( semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) -const ( - defaultWgKeepAlive = 25 * time.Second -) - type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -117,6 +113,8 @@ type Conn struct { // debug purpose dumpState *stateDump + + endpointUpdater *EndpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -129,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), } return conn, nil @@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.RemoveWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn wgProxy.Work() } - if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { + conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) + presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) + if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() + if conn.wgProxyRelay != nil { + conn.Log.Debugf("redirect packets from relayed conn to WireGuard") + conn.wgProxyRelay.RedirectAs(ep) + } + conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { + presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -418,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() { defer conn.wgWatcherWg.Done() conn.workerRelay.EnableWgWatcher(conn.ctx) }() + conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) @@ -477,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { + presharedKey := conn.presharedKey(rci.rosenpassPubKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -545,17 +554,6 @@ func (conn *Conn) onGuardEvent() { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { - presharedKey := conn.presharedKey(remoteRPKey) - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - addr, - presharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -698,10 +696,6 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 000000000..39cb95591 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,105 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + defaultWgKeepAlive = 25 * time.Second + fallbackDelay = 5 * time.Second +) + +type EndpointUpdater struct { + log *logrus.Entry + wgConfig WgConfig + initiator bool + + // mu protects updateWireGuardPeer and cancelFunc + mu sync.Mutex + cancelFunc func() + updateWg sync.WaitGroup +} + +func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater { + return &EndpointUpdater{ + log: log, + wgConfig: wgConfig, + initiator: initiator, + } +} + +// ConfigureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.initiator { + e.log.Debugf("configure up WireGuard as initiatr") + return e.updateWireGuardPeer(addr, presharedKey) + } + + // prevent to run new update while cancel the previous update + e.waitForCloseTheDelayedUpdate() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + e.updateWg.Add(1) + go e.scheduleDelayedUpdate(ctx, addr, presharedKey) + + e.log.Debugf("configure up WireGuard and wait for handshake") + return e.updateWireGuardPeer(nil, presharedKey) +} + +func (e *EndpointUpdater) RemoveWgPeer() error { + e.mu.Lock() + defer e.mu.Unlock() + + e.waitForCloseTheDelayedUpdate() + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { + if e.cancelFunc == nil { + return + } + + e.cancelFunc() + e.cancelFunc = nil + e.updateWg.Wait() +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) { + defer e.updateWg.Done() + t := time.NewTimer(fallbackDelay) + defer t.Stop() + + select { + case <-ctx.Done(): + return + case <-t.C: + e.mu.Lock() + if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + e.mu.Unlock() + } +} + +func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error { + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + presharedKey, + ) +} From 5e1a40c33fdd83427d9c5e4389e3338ffdbdf3bd Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 30 Sep 2025 23:40:46 +0200 Subject: [PATCH 24/30] [client] Order the list of candidates for proper comparison (#4561) Order the list of candidates for proper comparison --- client/internal/peer/guard/ice_monitor.go | 25 +++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/client/internal/peer/guard/ice_monitor.go b/client/internal/peer/guard/ice_monitor.go index 70850e6eb..09cf9ae63 100644 --- a/client/internal/peer/guard/ice_monitor.go +++ b/client/internal/peer/guard/ice_monitor.go @@ -3,6 +3,8 @@ package guard import ( "context" "fmt" + "slices" + "sort" "sync" "time" @@ -24,8 +26,8 @@ type ICEMonitor struct { iFaceDiscover stdnet.ExternalIFaceDiscover iceConfig icemaker.Config - currentCandidates []ice.Candidate - candidatesMu sync.Mutex + currentCandidatesAddress []string + candidatesMu sync.Mutex } func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor { @@ -115,16 +117,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool { cm.candidatesMu.Lock() defer cm.candidatesMu.Unlock() - if len(cm.currentCandidates) != len(newCandidates) { - cm.currentCandidates = newCandidates + newAddresses := make([]string, len(newCandidates)) + for i, c := range newCandidates { + newAddresses[i] = c.Address() + } + sort.Strings(newAddresses) + + if len(cm.currentCandidatesAddress) != len(newAddresses) { + cm.currentCandidatesAddress = newAddresses return true } - for i, candidate := range cm.currentCandidates { - if candidate.Address() != newCandidates[i].Address() { - cm.currentCandidates = newCandidates - return true - } + // Compare elements + if !slices.Equal(cm.currentCandidatesAddress, newAddresses) { + cm.currentCandidatesAddress = newAddresses + return true } return false From b5daec3b51ee01ea779f727c74a0aa394a7a3d5d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:10:11 +0200 Subject: [PATCH 25/30] [client,signal,management] Add browser client support (#4415) --- .github/workflows/golangci-lint.yml | 2 +- .github/workflows/wasm-build-validation.yml | 67 +++++ .gitmodules | 0 .goreleaser.yaml | 17 ++ client/cmd/debug_js.go | 8 + client/cmd/testutil_test.go | 4 +- client/embed/embed.go | 60 +++- client/grpc/dialer.go | 42 +-- client/grpc/dialer_generic.go | 44 +++ client/grpc/dialer_js.go | 12 + client/iface/bind/error.go | 7 + client/iface/bind/ice_bind.go | 58 ++-- client/iface/bind/recv_msg.go | 6 + client/iface/bind/relay_bind.go | 125 ++++++++ client/iface/configurer/name.go | 2 +- client/iface/configurer/uapi.go | 2 +- client/iface/configurer/uapi_js.go | 23 ++ client/iface/device/device_netstack.go | 28 +- client/iface/device/device_netstack_test.go | 27 ++ client/iface/iface_destroy_js.go | 6 + client/iface/iface_new_android.go | 4 +- client/iface/iface_new_darwin.go | 2 +- client/iface/iface_new_freebsd.go | 4 +- client/iface/iface_new_ios.go | 2 +- client/iface/iface_new_js.go | 27 ++ client/iface/iface_new_linux.go | 4 +- client/iface/iface_new_windows.go | 2 +- client/iface/netstack/env.go | 2 + client/iface/netstack/env_js.go | 12 + client/iface/wgproxy/bind/proxy.go | 23 +- client/iface/wgproxy/factory_usp.go | 11 +- client/iface/wgproxy/proxy_linux_test.go | 2 +- client/iface/wgproxy/proxy_seed_test.go | 2 +- client/internal/dns/server_js.go | 5 + client/internal/dns/unclean_shutdown_js.go | 19 ++ client/internal/engine.go | 20 +- client/internal/engine_generic.go | 19 ++ client/internal/engine_js.go | 18 ++ client/internal/engine_test.go | 8 +- .../networkmonitor/check_change_js.go | 12 + .../routemanager/systemops/systemops_js.go | 48 ++++ .../systemops/systemops_nonlinux.go | 2 +- client/server/server_test.go | 17 +- client/ssh/client.go | 2 + client/ssh/login.go | 2 + client/ssh/server.go | 2 + client/ssh/server_mock.go | 2 + client/ssh/server_test.go | 2 + client/ssh/ssh_js.go | 137 +++++++++ client/ssh/util.go | 2 + client/system/info_js.go | 231 +++++++++++++++ client/wasm/cmd/main.go | 245 ++++++++++++++++ client/wasm/internal/http/http.go | 100 +++++++ client/wasm/internal/rdp/cert_validation.go | 96 +++++++ client/wasm/internal/rdp/rdcleanpath.go | 271 ++++++++++++++++++ .../wasm/internal/rdp/rdcleanpath_handlers.go | 251 ++++++++++++++++ client/wasm/internal/ssh/client.go | 213 ++++++++++++++ client/wasm/internal/ssh/handlers.go | 78 +++++ client/wasm/internal/ssh/key.go | 50 ++++ encryption/route53.go | 2 + flow/client/client.go | 5 +- go.mod | 2 +- go.sum | 4 +- management/internals/server/controllers.go | 8 +- management/internals/server/modules.go | 4 + management/internals/server/server.go | 32 ++- management/server/account.go | 6 + management/server/account/manager.go | 4 +- management/server/account_test.go | 38 +-- management/server/dns_test.go | 4 +- management/server/grpcserver.go | 5 +- .../http/handlers/peers/peers_handler.go | 83 ++++++ management/server/management_proto_test.go | 3 +- management/server/management_test.go | 3 +- management/server/mock_server/account_mock.go | 12 +- management/server/nameserver_test.go | 4 +- .../server/networks/resources/manager.go | 4 +- management/server/peer.go | 58 +++- management/server/peer/peer.go | 12 + management/server/peer_test.go | 74 ++--- .../server/peers/ephemeral/interface.go | 14 + .../ephemeral/manager}/ephemeral.go | 2 +- .../ephemeral/manager}/ephemeral_test.go | 59 +++- management/server/policy.go | 12 + management/server/store/sql_store.go | 19 ++ management/server/store/store.go | 1 + management/server/types/account.go | 32 ++- management/server/types/policy.go | 87 ++++++ management/server/types/resource.go | 13 +- management/server/user_test.go | 4 +- shared/management/client/client_test.go | 4 +- shared/management/http/api/openapi.yml | 81 +++++- shared/management/http/api/types.gen.go | 28 ++ shared/management/proto/management.pb.go | 2 +- shared/relay/client/client.go | 12 +- shared/relay/client/dialer/ws/conn.go | 3 +- .../client/dialer/ws/dialopts_generic.go | 11 + shared/relay/client/dialer/ws/dialopts_js.go | 10 + shared/relay/client/dialer/ws/ws.go | 4 +- shared/relay/client/dialers_generic.go | 19 ++ shared/relay/client/dialers_js.go | 13 + signal/cmd/run.go | 52 +++- util/util_js.go | 8 + util/wsproxy/client/dialer_js.go | 171 +++++++++++ util/wsproxy/constants.go | 13 + util/wsproxy/server/metrics.go | 118 ++++++++ util/wsproxy/server/proxy.go | 227 +++++++++++++++ 107 files changed, 3591 insertions(+), 284 deletions(-) create mode 100644 .github/workflows/wasm-build-validation.yml create mode 100644 .gitmodules create mode 100644 client/cmd/debug_js.go create mode 100644 client/grpc/dialer_generic.go create mode 100644 client/grpc/dialer_js.go create mode 100644 client/iface/bind/error.go create mode 100644 client/iface/bind/recv_msg.go create mode 100644 client/iface/bind/relay_bind.go create mode 100644 client/iface/configurer/uapi_js.go create mode 100644 client/iface/device/device_netstack_test.go create mode 100644 client/iface/iface_destroy_js.go create mode 100644 client/iface/iface_new_js.go create mode 100644 client/iface/netstack/env_js.go create mode 100644 client/internal/dns/server_js.go create mode 100644 client/internal/dns/unclean_shutdown_js.go create mode 100644 client/internal/engine_generic.go create mode 100644 client/internal/engine_js.go create mode 100644 client/internal/networkmonitor/check_change_js.go create mode 100644 client/internal/routemanager/systemops/systemops_js.go create mode 100644 client/ssh/ssh_js.go create mode 100644 client/system/info_js.go create mode 100644 client/wasm/cmd/main.go create mode 100644 client/wasm/internal/http/http.go create mode 100644 client/wasm/internal/rdp/cert_validation.go create mode 100644 client/wasm/internal/rdp/rdcleanpath.go create mode 100644 client/wasm/internal/rdp/rdcleanpath_handlers.go create mode 100644 client/wasm/internal/ssh/client.go create mode 100644 client/wasm/internal/ssh/handlers.go create mode 100644 client/wasm/internal/ssh/key.go create mode 100644 management/server/peers/ephemeral/interface.go rename management/server/{ => peers/ephemeral/manager}/ephemeral.go (99%) rename management/server/{ => peers/ephemeral/manager}/ephemeral_test.go (75%) create mode 100644 shared/relay/client/dialer/ws/dialopts_generic.go create mode 100644 shared/relay/client/dialer/ws/dialopts_js.go create mode 100644 shared/relay/client/dialers_generic.go create mode 100644 shared/relay/client/dialers_js.go create mode 100644 util/util_js.go create mode 100644 util/wsproxy/client/dialer_js.go create mode 100644 util/wsproxy/constants.go create mode 100644 util/wsproxy/server/metrics.go create mode 100644 util/wsproxy/server/proxy.go diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 7e6583cc6..2845b05a5 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe + ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros skip: go.mod,go.sum golangci: strategy: diff --git a/.github/workflows/wasm-build-validation.yml b/.github/workflows/wasm-build-validation.yml new file mode 100644 index 000000000..e4ac799bc --- /dev/null +++ b/.github/workflows/wasm-build-validation.yml @@ -0,0 +1,67 @@ +name: Wasm + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} + cancel-in-progress: true + +jobs: + js_lint: + name: "JS / Lint" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev + - name: Install golangci-lint + uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc + with: + version: latest + install-mode: binary + skip-cache: true + skip-pkg-cache: true + skip-build-cache: true + - name: Run golangci-lint for WASM + run: | + GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/... + continue-on-error: true + + js_build: + name: "JS / Build" + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + - name: Build Wasm client + run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd + env: + CGO_ENABLED: 0 + - name: Check Wasm build size + run: | + echo "Wasm build size:" + ls -lh netbird.wasm + + SIZE=$(stat -c%s netbird.wasm) + SIZE_MB=$((SIZE / 1024 / 1024)) + + echo "Size: ${SIZE} bytes (${SIZE_MB} MB)" + + if [ ${SIZE} -gt 52428800 ]; then + echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!" + exit 1 + fi + diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..e69de29bb diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 59a95c89a..952e946dc 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -2,6 +2,18 @@ version: 2 project_name: netbird builds: + - id: netbird-wasm + dir: client/wasm/cmd + binary: netbird + env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0] + goos: + - js + goarch: + - wasm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" + - id: netbird dir: client binary: netbird @@ -115,6 +127,11 @@ archives: - builds: - netbird - netbird-static + - id: netbird-wasm + builds: + - netbird-wasm + name_template: "{{ .ProjectName }}_{{ .Version }}" + format: binary nfpms: - maintainer: Netbird diff --git a/client/cmd/debug_js.go b/client/cmd/debug_js.go new file mode 100644 index 000000000..d06fb8efc --- /dev/null +++ b/client/cmd/debug_js.go @@ -0,0 +1,8 @@ +package cmd + +import "context" + +// SetupDebugHandler is a no-op for WASM +func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) { + // Debug handler not needed for WASM +} diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 99ccb1539..bd3209605 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/management/internals/server/config" @@ -20,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/client/embed/embed.go b/client/embed/embed.go index 0bfc7a37c..e918235ed 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -23,23 +23,29 @@ import ( var ErrClientAlreadyStarted = errors.New("client already started") var ErrClientNotStarted = errors.New("client not started") +var ErrConfigNotInitialized = errors.New("config not initialized") -// Client manages a netbird embedded client instance +// Client manages a netbird embedded client instance. type Client struct { deviceName string config *profilemanager.Config mu sync.Mutex cancel context.CancelFunc setupKey string + jwtToken string connect *internal.ConnectClient } -// Options configures a new Client +// Options configures a new Client. type Options struct { // DeviceName is this peer's name in the network DeviceName string // SetupKey is used for authentication SetupKey string + // JWTToken is used for JWT-based authentication + JWTToken string + // PrivateKey is used for direct private key authentication + PrivateKey string // ManagementURL overrides the default management server URL ManagementURL string // PreSharedKey is the pre-shared key for the WireGuard interface @@ -58,8 +64,35 @@ type Options struct { DisableClientRoutes bool } -// New creates a new netbird embedded client +// validateCredentials checks that exactly one credential type is provided +func (opts *Options) validateCredentials() error { + credentialsProvided := 0 + if opts.SetupKey != "" { + credentialsProvided++ + } + if opts.JWTToken != "" { + credentialsProvided++ + } + if opts.PrivateKey != "" { + credentialsProvided++ + } + + if credentialsProvided == 0 { + return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided") + } + if credentialsProvided > 1 { + return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified") + } + + return nil +} + +// New creates a new netbird embedded client. func New(opts Options) (*Client, error) { + if err := opts.validateCredentials(); err != nil { + return nil, err + } + if opts.LogOutput != nil { logrus.SetOutput(opts.LogOutput) } @@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) { return nil, fmt.Errorf("create config: %w", err) } + if opts.PrivateKey != "" { + config.PrivateKey = opts.PrivateKey + } + return &Client{ deviceName: opts.DeviceName, setupKey: opts.SetupKey, + jwtToken: opts.JWTToken, config: config, }, nil } @@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error { ctx := internal.CtxInitState(context.Background()) // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) - if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil { + if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } @@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error { } } +// GetConfig returns a copy of the internal client config. +func (c *Client) GetConfig() (profilemanager.Config, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.config == nil { + return profilemanager.Config{}, ErrConfigNotInitialized + } + return *c.config, nil +} + // Dial dials a network address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { @@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e return nsnet.DialContext(ctx, network, address) } -// ListenTCP listens on the given address in the netbird network +// ListenTCP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenTCP(address string) (net.Listener, error) { nsnet, addr, err := c.getNet() @@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) { return nsnet.ListenTCP(tcpAddr) } -// ListenUDP listens on the given address in the netbird network +// ListenUDP listens on the given address in the netbird network. // Not applicable if the userspace networking mode is disabled. func (c *Client) ListenUDP(address string) (net.PacketConn, error) { nsnet, addr, err := c.getNet() diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 69e3f088c..7cb38fbff 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,15 +4,9 @@ import ( "context" "crypto/tls" "crypto/x509" - "fmt" - "net" - "os/user" "runtime" "time" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -20,37 +14,10 @@ import ( "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" - nbnet "github.com/netbirdio/netbird/client/net" - "github.com/netbirdio/netbird/util/embeddedroots" ) -func WithCustomDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - if runtime.GOOS == "linux" { - currentUser, err := user.Current() - if err != nil { - return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) - } - - // the custom dialer requires root permissions which are not required for use cases run as non-root - if currentUser.Uid != "0" { - log.Debug("Not running as root, using standard dialer") - dialer := &net.Dialer{} - return dialer.DialContext(ctx, "tcp", addr) - } - } - - conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) - if err != nil { - log.Errorf("Failed to dial: %s", err) - return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) - } - return conn, nil - }) -} - -// grpcDialBackoff is the backoff mechanism for the grpc calls +// Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() b.MaxElapsedTime = 10 * time.Second @@ -58,6 +25,7 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// CreateConnection creates a gRPC client connection with the appropriate transport options func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { @@ -68,7 +36,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. } transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ - RootCAs: certPool, + // for js, outer websocket layer takes care of tls verification via WithCustomDialer + InsecureSkipVerify: runtime.GOOS == "js", + RootCAs: certPool, })) } @@ -79,7 +49,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. connCtx, addr, transportOption, - WithCustomDialer(), + WithCustomDialer(tlsEnabled), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go new file mode 100644 index 000000000..a0d6cee0b --- /dev/null +++ b/client/grpc/dialer_generic.go @@ -0,0 +1,44 @@ +//go:build !js + +package grpc + +import ( + "context" + "fmt" + "net" + "os/user" + "runtime" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func WithCustomDialer(tlsEnabled bool) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + if runtime.GOOS == "linux" { + currentUser, err := user.Current() + if err != nil { + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) + } + + // the custom dialer requires root permissions which are not required for use cases run as non-root + if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") + dialer := &net.Dialer{} + return dialer.DialContext(ctx, "tcp", addr) + } + } + + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) + } + return conn, nil + }) +} diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go new file mode 100644 index 000000000..e132c0098 --- /dev/null +++ b/client/grpc/dialer_js.go @@ -0,0 +1,12 @@ +package grpc + +import ( + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy/client" +) + +// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. +func WithCustomDialer(tlsEnabled bool) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled) +} diff --git a/client/iface/bind/error.go b/client/iface/bind/error.go new file mode 100644 index 000000000..db7c23144 --- /dev/null +++ b/client/iface/bind/error.go @@ -0,0 +1,7 @@ +package bind + +import "fmt" + +var ( + ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM") +) diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index ef630b9d0..dfb22ecde 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,3 +1,5 @@ +//go:build !js + package bind import ( @@ -21,11 +23,6 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -type RecvMessage struct { - Endpoint *Endpoint - Buffer []byte -} - type receiverCreator struct { iceBind *ICEBind } @@ -43,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - recvChan chan RecvMessage transportNet transport.Net filterFn udpmux.FilterFn - endpoints map[netip.Addr]net.Conn - endpointsMu sync.Mutex + address wgaddr.Address + mtu uint16 + + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + recvChan chan recvMessage // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // new closed channel. With the closedChanMu we can safely close the channel and create a new one - closedChan chan struct{} - closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. - closed bool - - muUDPMux sync.Mutex - udpMux *udpmux.UniversalUDPMuxDefault - address wgaddr.Address - mtu uint16 + closedChan chan struct{} + closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. + closed bool activityRecorder *ActivityRecorder + + muUDPMux sync.Mutex + udpMux *udpmux.UniversalUDPMuxDefault } func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, + address: address, + mtu: mtu, endpoints: make(map[netip.Addr]net.Conn), + recvChan: make(chan recvMessage, 1), closedChan: make(chan struct{}), closed: true, - mtu: mtu, - address: address, activityRecorder: NewActivityRecorder(), } @@ -84,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg return ib } -func (s *ICEBind) MTU() uint16 { - return s.mtu -} - func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { s.closed = false s.closedChanMu.Lock() @@ -140,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) { delete(b.endpoints, fakeIP) } +func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-b.closedChan: + return + case <-ctx.Done(): + return + case b.recvChan <- recvMessage{ep, buf}: + } +} + func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { b.endpointsMu.Lock() conn, ok := b.endpoints[ep.DstIP()] @@ -156,14 +160,6 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { - select { - case <-ctx.Done(): - return - case b.recvChan <- msg: - } -} - func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() diff --git a/client/iface/bind/recv_msg.go b/client/iface/bind/recv_msg.go new file mode 100644 index 000000000..65baffaac --- /dev/null +++ b/client/iface/bind/recv_msg.go @@ -0,0 +1,6 @@ +package bind + +type recvMessage struct { + Endpoint *Endpoint + Buffer []byte +} diff --git a/client/iface/bind/relay_bind.go b/client/iface/bind/relay_bind.go new file mode 100644 index 000000000..4c179d6a5 --- /dev/null +++ b/client/iface/bind/relay_bind.go @@ -0,0 +1,125 @@ +package bind + +import ( + "context" + "net" + "net/netip" + "sync" + + log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" + + "github.com/netbirdio/netbird/client/iface/udpmux" +) + +// RelayBindJS is a conn.Bind implementation for WebAssembly environments. +// Do not limit to build only js, because we want to be able to run tests +type RelayBindJS struct { + *conn.StdNetBind + + recvChan chan recvMessage + endpoints map[netip.Addr]net.Conn + endpointsMu sync.Mutex + activityRecorder *ActivityRecorder + ctx context.Context + cancel context.CancelFunc +} + +func NewRelayBindJS() *RelayBindJS { + return &RelayBindJS{ + recvChan: make(chan recvMessage, 100), + endpoints: make(map[netip.Addr]net.Conn), + activityRecorder: NewActivityRecorder(), + } +} + +// Open creates a receive function for handling relay packets in WASM. +func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { + log.Debugf("Open: creating receive function for port %d", uport) + + s.ctx, s.cancel = context.WithCancel(context.Background()) + + receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + select { + case <-s.ctx.Done(): + return 0, net.ErrClosed + case msg, ok := <-s.recvChan: + if !ok { + return 0, net.ErrClosed + } + copy(bufs[0], msg.Buffer) + sizes[0] = len(msg.Buffer) + eps[0] = conn.Endpoint(msg.Endpoint) + return 1, nil + } + } + + log.Debugf("Open: receive function created, returning port %d", uport) + return []conn.ReceiveFunc{receiveFn}, uport, nil +} + +func (s *RelayBindJS) Close() error { + if s.cancel == nil { + return nil + } + log.Debugf("close RelayBindJS") + s.cancel() + return nil +} + +func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) { + select { + case <-s.ctx.Done(): + return + case <-ctx.Done(): + return + case s.recvChan <- recvMessage{ep, buf}: + } +} + +// Send forwards packets through the relay connection for WASM. +func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error { + if ep == nil { + return nil + } + + fakeIP := ep.DstIP() + + s.endpointsMu.Lock() + relayConn, ok := s.endpoints[fakeIP] + s.endpointsMu.Unlock() + + if !ok { + return nil + } + + for _, buf := range bufs { + if _, err := relayConn.Write(buf); err != nil { + return err + } + } + + return nil +} + +func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { + b.endpointsMu.Lock() + b.endpoints[fakeIP] = conn + b.endpointsMu.Unlock() +} + +func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) { + s.endpointsMu.Lock() + defer s.endpointsMu.Unlock() + + delete(s.endpoints, fakeIP) +} + +// GetICEMux returns the ICE UDPMux that was created and used by ICEBind +func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, ErrUDPMUXNotSupported +} + +func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder { + return s.activityRecorder +} diff --git a/client/iface/configurer/name.go b/client/iface/configurer/name.go index 3b9abc0e8..a8469e0b4 100644 --- a/client/iface/configurer/name.go +++ b/client/iface/configurer/name.go @@ -1,4 +1,4 @@ -//go:build linux || windows || freebsd +//go:build linux || windows || freebsd || js || wasip1 package configurer diff --git a/client/iface/configurer/uapi.go b/client/iface/configurer/uapi.go index 4801841de..f85c7852a 100644 --- a/client/iface/configurer/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,4 +1,4 @@ -//go:build !windows +//go:build !windows && !js package configurer diff --git a/client/iface/configurer/uapi_js.go b/client/iface/configurer/uapi_js.go new file mode 100644 index 000000000..d0188eb35 --- /dev/null +++ b/client/iface/configurer/uapi_js.go @@ -0,0 +1,23 @@ +package configurer + +import ( + "net" +) + +type noopListener struct{} + +func (n *noopListener) Accept() (net.Conn, error) { + return nil, net.ErrClosed +} + +func (n *noopListener) Close() error { + return nil +} + +func (n *noopListener) Addr() net.Addr { + return nil +} + +func openUAPI(deviceName string) (net.Listener, error) { + return &noopListener{}, nil +} diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index a6ef47027..e37321b68 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,9 +1,11 @@ package device import ( + "errors" "fmt" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun/netstack" @@ -15,6 +17,12 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +type Bind interface { + conn.Bind + GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) + ActivityRecorder() *bind.ActivityRecorder +} + type TunNetstackDevice struct { name string address wgaddr.Address @@ -22,7 +30,7 @@ type TunNetstackDevice struct { key string mtu uint16 listenAddress string - iceBind *bind.ICEBind + bind Bind device *device.Device filteredDevice *FilteredDevice @@ -33,7 +41,7 @@ type TunNetstackDevice struct { net *netstack.Net } -func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri key: key, mtu: mtu, listenAddress: listenAddress, - iceBind: iceBind, + bind: bind, } } @@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) { t.device = device.NewDevice( t.filteredDevice, - t.iceBind, + t.bind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() @@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { return nil, err } - udpMux, err := t.iceBind.GetICEMux() - if err != nil { + udpMux, err := t.bind.GetICEMux() + if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) { return nil, err } - t.udpMux = udpMux + + if udpMux != nil { + t.udpMux = udpMux + } + log.Debugf("netstack device is ready to use") return udpMux, nil } diff --git a/client/iface/device/device_netstack_test.go b/client/iface/device/device_netstack_test.go new file mode 100644 index 000000000..52059602f --- /dev/null +++ b/client/iface/device/device_netstack_test.go @@ -0,0 +1,27 @@ +package device + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestNewNetstackDevice(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24") + + relayBind := bind.NewRelayBindJS() + nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr()) + + cfgr, err := nsTun.Create() + if err != nil { + t.Fatalf("failed to create netstack device: %v", err) + } + if cfgr == nil { + t.Fatal("expected non-nil configurer") + } +} diff --git a/client/iface/iface_destroy_js.go b/client/iface/iface_destroy_js.go new file mode 100644 index 000000000..b443273c3 --- /dev/null +++ b/client/iface/iface_destroy_js.go @@ -0,0 +1,6 @@ +package iface + +// Destroy is a no-op on WASM +func (w *WGIface) Destroy() error { + return nil +} diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 26952f48d..3b68f63f2 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS), - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index 7dd74d571..9f21ec950 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go index 86ed14ce1..a342bd579 100644 --- a/client/iface/iface_new_freebsd.go +++ b/client/iface/iface_new_freebsd.go @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -33,7 +33,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 06ccf0be1..5d6a32e39 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd), userspaceBind: true, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil } diff --git a/client/iface/iface_new_js.go b/client/iface/iface_new_js.go new file mode 100644 index 000000000..ad913ab04 --- /dev/null +++ b/client/iface/iface_new_js.go @@ -0,0 +1,27 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode) +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + relayBind := bind.NewRelayBindJS() + + wgIface := &WGIface{ + tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()), + userspaceBind: true, + wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU), + } + + return wgIface, nil +} diff --git a/client/iface/iface_new_linux.go b/client/iface/iface_new_linux.go index 77fd30fae..d84035403 100644 --- a/client/iface/iface_new_linux.go +++ b/client/iface/iface_new_linux.go @@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } @@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true - wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU) return wgIFace, nil } diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 349c5b33b..dfd9028e7 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{ userspaceBind: true, tun: tun, - wgProxyFactory: wgproxy.NewUSPFactory(iceBind), + wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU), } return wgIFace, nil diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index cdbf975b1..dd8cf29a3 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -1,3 +1,5 @@ +//go:build !js + package netstack import ( diff --git a/client/iface/netstack/env_js.go b/client/iface/netstack/env_js.go new file mode 100644 index 000000000..05c20f036 --- /dev/null +++ b/client/iface/netstack/env_js.go @@ -0,0 +1,12 @@ +package netstack + +const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE" + +// IsEnabled always returns true for js since it's the only mode available +func IsEnabled() bool { + return true +} + +func ListenAddr() string { + return "" +} diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index dbc694e91..eb585d8a2 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,15 +16,14 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) -type IceBind interface { - SetEndpoint(fakeIP netip.Addr, conn net.Conn) - RemoveEndpoint(fakeIP netip.Addr) - Recv(ctx context.Context, msg bind.RecvMessage) - MTU() uint16 +type Bind interface { + SetEndpoint(addr netip.Addr, conn net.Conn) + RemoveEndpoint(addr netip.Addr) + ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte) } type ProxyBind struct { - bind IceBind + bind Bind // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address wgRelayedEndpoint *bind.Endpoint @@ -40,13 +39,15 @@ type ProxyBind struct { isStarted bool closeListener *listener.CloseListener + mtu uint16 } -func NewProxyBind(bind IceBind) *ProxyBind { +func NewProxyBind(bind Bind, mtu uint16) *ProxyBind { p := &ProxyBind{ bind: bind, closeListener: listener.NewCloseListener(), pausedCond: sync.NewCond(&sync.Mutex{}), + mtu: mtu + bufsize.WGBufferOverhead, } return p @@ -174,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.mtu) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -190,11 +191,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { p.pausedCond.Wait() } - msg := bind.RecvMessage{ - Endpoint: p.wgCurrentUsed, - Buffer: buf[:n], - } - p.bind.Recv(ctx, msg) + p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n]) p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index 141b4c1f9..a1b1c34d7 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -3,24 +3,25 @@ package wgproxy import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface/bind" proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind" ) type USPFactory struct { - bind *bind.ICEBind + bind proxyBind.Bind + mtu uint16 } -func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { +func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory { log.Infof("WireGuard Proxy Factory will produce bind proxy") f := &USPFactory{ - bind: iceBind, + bind: bind, + mtu: mtu, } return f } func (w *USPFactory) GetProxy() Proxy { - return proxyBind.NewProxyBind(w.bind) + return proxyBind.NewProxyBind(w.bind, w.mtu) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 9526e91d2..dd24d1cdc 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -74,7 +74,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { pBind := proxyInstance{ name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), + proxy: bindproxy.NewProxyBind(iceBind, 0), endpointAddr: endpointAddress, closeFn: func() error { return nil }, } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go index 4d244f18a..ad375ccde 100644 --- a/client/iface/wgproxy/proxy_seed_test.go +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -30,7 +30,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { pBind := proxyInstance{ name: "bind proxy", - proxy: bindproxy.NewProxyBind(iceBind), + proxy: bindproxy.NewProxyBind(iceBind, 0), endpointAddr: endpointAddress, closeFn: func() error { return nil }, } diff --git a/client/internal/dns/server_js.go b/client/internal/dns/server_js.go new file mode 100644 index 000000000..a8bc35d09 --- /dev/null +++ b/client/internal/dns/server_js.go @@ -0,0 +1,5 @@ +package dns + +func (s *DefaultServer) initialize() (hostManager, error) { + return &noopHostConfigurator{}, nil +} diff --git a/client/internal/dns/unclean_shutdown_js.go b/client/internal/dns/unclean_shutdown_js.go new file mode 100644 index 000000000..378ffc164 --- /dev/null +++ b/client/internal/dns/unclean_shutdown_js.go @@ -0,0 +1,19 @@ +package dns + +import ( + "context" +) + +type ShutdownState struct{} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} + +func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error { + return nil +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 828bc6e94..3fa0b58a8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -453,8 +453,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } - - // if inbound conns are blocked there is no need to create the ACL manager if e.firewall != nil && !e.config.BlockInbound { e.acl = acl.NewDefaultManager(e.firewall) @@ -466,14 +464,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("initialize dns server: %w", err) } - iceCfg := icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - } + iceCfg := e.createICEConfig() e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr.Start(e.ctx) @@ -1347,14 +1338,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV Addr: e.getRosenpassAddr(), PermissiveMode: e.config.RosenpassPermissive, }, - ICEConfig: icemaker.Config{ - StunTurn: &e.stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - UDPMux: e.udpMux.SingleSocketUDPMux, - UDPMuxSrflx: e.udpMux, - NATExternalIPs: e.parseNATExternalIPMappings(), - }, + ICEConfig: e.createICEConfig(), } serviceDependencies := peer.ServiceDependencies{ diff --git a/client/internal/engine_generic.go b/client/internal/engine_generic.go new file mode 100644 index 000000000..34a75e45b --- /dev/null +++ b/client/internal/engine_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for non-WASM environments +func (e *Engine) createICEConfig() icemaker.Config { + return icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.SingleSocketUDPMux, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + } +} diff --git a/client/internal/engine_js.go b/client/internal/engine_js.go new file mode 100644 index 000000000..dce3c57fb --- /dev/null +++ b/client/internal/engine_js.go @@ -0,0 +1,18 @@ +//go:build js + +package internal + +import ( + icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" +) + +// createICEConfig creates ICE configuration for WASM environment. +func (e *Engine) createICEConfig() icemaker.Config { + cfg := icemaker.Config{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + NATExternalIPs: e.parseNATExternalIPMappings(), + } + return cfg +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 4d2e81f43..344104405 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -27,6 +27,10 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" @@ -42,10 +46,8 @@ import ( "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -1584,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/internal/networkmonitor/check_change_js.go b/client/internal/networkmonitor/check_change_js.go new file mode 100644 index 000000000..640cf7184 --- /dev/null +++ b/client/internal/networkmonitor/check_change_js.go @@ -0,0 +1,12 @@ +package networkmonitor + +import ( + "context" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + // No-op for WASM - network changes don't apply + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go new file mode 100644 index 000000000..808507fc9 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -0,0 +1,48 @@ +package systemops + +import ( + "errors" + "net" + "net/netip" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +var ErrRouteNotSupported = errors.New("route operations not supported on js") + +func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error { + return ErrRouteNotSupported +} + +func GetRoutesFromTable() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +func hasSeparateRouting() ([]netip.Prefix, error) { + return []netip.Prefix{}, nil +} + +// GetDetailedRoutesFromTable returns empty routes for WASM. +func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { + return []DetailedRoute{}, nil +} + +func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error { + return ErrRouteNotSupported +} + +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error { + return nil +} + +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error { + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 83b64e82b..905a7bc12 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -1,4 +1,4 @@ -//go:build !linux && !ios +//go:build !linux && !ios && !js package systemops diff --git a/client/server/server_test.go b/client/server/server_test.go index 755925003..e0a4805f6 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -10,23 +10,26 @@ import ( "time" "github.com/golang/mock/gomock" - log "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" daemonProto "github.com/netbirdio/netbird/client/proto" - "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" @@ -314,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/client/ssh/client.go b/client/ssh/client.go index 2dc70e8fc..afba347f8 100644 --- a/client/ssh/client.go +++ b/client/ssh/client.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/login.go b/client/ssh/login.go index d1d56ceb0..cb2615e55 100644 --- a/client/ssh/login.go +++ b/client/ssh/login.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server.go b/client/ssh/server.go index 1f2001d0f..8c5db2547 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/server_mock.go b/client/ssh/server_mock.go index cc080ffdb..76f43fd4e 100644 --- a/client/ssh/server_mock.go +++ b/client/ssh/server_mock.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import "context" diff --git a/client/ssh/server_test.go b/client/ssh/server_test.go index 5caca1834..1f310c2bb 100644 --- a/client/ssh/server_test.go +++ b/client/ssh/server_test.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/ssh/ssh_js.go b/client/ssh/ssh_js.go new file mode 100644 index 000000000..8cea88702 --- /dev/null +++ b/client/ssh/ssh_js.go @@ -0,0 +1,137 @@ +package ssh + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + "strings" + + "golang.org/x/crypto/ssh" +) + +var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment") + +// Server is a dummy SSH server interface for WASM. +type Server interface { + Start() error + Stop() error + EnableSSH(enabled bool) + AddAuthorizedKey(peer string, key string) error + RemoveAuthorizedKey(key string) +} + +type dummyServer struct{} + +func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { + return &dummyServer{}, nil +} + +func NewServer(addr string) Server { + return &dummyServer{} +} + +func (s *dummyServer) Start() error { + return ErrSSHNotSupported +} + +func (s *dummyServer) Stop() error { + return nil +} + +func (s *dummyServer) EnableSSH(enabled bool) { +} + +func (s *dummyServer) AddAuthorizedKey(peer string, key string) error { + return nil +} + +func (s *dummyServer) RemoveAuthorizedKey(key string) { +} + +type Client struct{} + +func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) { + return nil, ErrSSHNotSupported +} + +func (c *Client) Close() error { + return nil +} + +func (c *Client) Run(command []string) error { + return ErrSSHNotSupported +} + +type SessionRecorder struct{} + +func NewSessionRecorder() *SessionRecorder { + return &SessionRecorder{} +} + +func (r *SessionRecorder) Record(session string, data []byte) { +} + +func GetUserShell() string { + return "/bin/sh" +} + +func LookupUserInfo(username string) (string, string, error) { + return "", "", ErrSSHNotSupported +} + +const DefaultSSHPort = 44338 + +const ED25519 = "ed25519" + +func isRoot() bool { + return false +} + +func GeneratePrivateKey(keyType string) ([]byte, error) { + if keyType != ED25519 { + return nil, errors.New("only ED25519 keys are supported in WASM") + } + + _, privateKey, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, err + } + + pemBlock := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + } + + pemBytes := pem.EncodeToMemory(pemBlock) + return pemBytes, nil +} + +func GeneratePublicKey(privateKey []byte) ([]byte, error) { + signer, err := ssh.ParsePrivateKey(privateKey) + if err != nil { + block, _ := pem.Decode(privateKey) + if block != nil { + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + signer, err = ssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } + + pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey()) + return []byte(strings.TrimSpace(string(pubKeyBytes))), nil +} diff --git a/client/ssh/util.go b/client/ssh/util.go index cf5f1396e..a54a609bc 100644 --- a/client/ssh/util.go +++ b/client/ssh/util.go @@ -1,3 +1,5 @@ +//go:build !js + package ssh import ( diff --git a/client/system/info_js.go b/client/system/info_js.go new file mode 100644 index 000000000..994d439a7 --- /dev/null +++ b/client/system/info_js.go @@ -0,0 +1,231 @@ +package system + +import ( + "context" + "runtime" + "strings" + "syscall/js" + + "github.com/netbirdio/netbird/version" +) + +// UpdateStaticInfoAsync is a no-op on JS as there is no static info to update +func UpdateStaticInfoAsync() { + // do nothing +} + +// GetInfo retrieves system information for WASM environment +func GetInfo(_ context.Context) *Info { + info := &Info{ + GoOS: runtime.GOOS, + Kernel: runtime.GOARCH, + KernelVersion: runtime.GOARCH, + Platform: runtime.GOARCH, + OS: runtime.GOARCH, + Hostname: "wasm-client", + CPUs: runtime.NumCPU(), + NetbirdVersion: version.NetbirdVersion(), + } + + collectBrowserInfo(info) + collectLocationInfo(info) + collectSystemInfo(info) + return info +} + +func collectBrowserInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + collectUserAgent(info, navigator) + collectPlatform(info, navigator) + collectCPUInfo(info, navigator) +} + +func collectUserAgent(info *Info, navigator js.Value) { + ua := navigator.Get("userAgent") + if ua.IsUndefined() { + return + } + + userAgent := ua.String() + os, osVersion := parseOSFromUserAgent(userAgent) + if os != "" { + info.OS = os + } + if osVersion != "" { + info.OSVersion = osVersion + } +} + +func collectPlatform(info *Info, navigator js.Value) { + // Try regular platform property + if plat := navigator.Get("platform"); !plat.IsUndefined() { + if platStr := plat.String(); platStr != "" { + info.Platform = platStr + } + } + + // Try newer userAgentData API for more accurate platform + userAgentData := navigator.Get("userAgentData") + if userAgentData.IsUndefined() { + return + } + + platformInfo := userAgentData.Get("platform") + if !platformInfo.IsUndefined() { + if platStr := platformInfo.String(); platStr != "" { + info.Platform = platStr + } + } +} + +func collectCPUInfo(info *Info, navigator js.Value) { + hardwareConcurrency := navigator.Get("hardwareConcurrency") + if !hardwareConcurrency.IsUndefined() { + info.CPUs = hardwareConcurrency.Int() + } +} + +func collectLocationInfo(info *Info) { + location := js.Global().Get("location") + if location.IsUndefined() { + return + } + + if host := location.Get("hostname"); !host.IsUndefined() { + hostnameStr := host.String() + if hostnameStr != "" && hostnameStr != "localhost" { + info.Hostname = hostnameStr + } + } +} + +func checkFileAndProcess(_ []string) ([]File, error) { + return []File{}, nil +} + +func collectSystemInfo(info *Info) { + navigator := js.Global().Get("navigator") + if navigator.IsUndefined() { + return + } + + if vendor := navigator.Get("vendor"); !vendor.IsUndefined() { + info.SystemManufacturer = vendor.String() + } + + if product := navigator.Get("product"); !product.IsUndefined() { + info.SystemProductName = product.String() + } + + if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() { + ua := userAgent.String() + info.Environment = detectEnvironmentFromUA(ua) + } +} + +func parseOSFromUserAgent(userAgent string) (string, string) { + if userAgent == "" { + return "", "" + } + + switch { + case strings.Contains(userAgent, "Windows NT"): + return parseWindowsVersion(userAgent) + case strings.Contains(userAgent, "Mac OS X"): + return parseMacOSVersion(userAgent) + case strings.Contains(userAgent, "FreeBSD"): + return "FreeBSD", "" + case strings.Contains(userAgent, "OpenBSD"): + return "OpenBSD", "" + case strings.Contains(userAgent, "NetBSD"): + return "NetBSD", "" + case strings.Contains(userAgent, "Linux"): + return parseLinuxVersion(userAgent) + case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"): + return parseiOSVersion(userAgent) + case strings.Contains(userAgent, "CrOS"): + return "ChromeOS", "" + default: + return "", "" + } +} + +func parseWindowsVersion(userAgent string) (string, string) { + switch { + case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"): + return "Windows", "10/11" + case strings.Contains(userAgent, "Windows NT 10.0"): + return "Windows", "10" + case strings.Contains(userAgent, "Windows NT 6.3"): + return "Windows", "8.1" + case strings.Contains(userAgent, "Windows NT 6.2"): + return "Windows", "8" + case strings.Contains(userAgent, "Windows NT 6.1"): + return "Windows", "7" + default: + return "Windows", "Unknown" + } +} + +func parseMacOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "Mac OS X ") + if idx == -1 { + return "macOS", "Unknown" + } + + versionStart := idx + len("Mac OS X ") + versionEnd := strings.Index(userAgent[versionStart:], ")") + if versionEnd <= 0 { + return "macOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "macOS", ver +} + +func parseLinuxVersion(userAgent string) (string, string) { + if strings.Contains(userAgent, "Android") { + return "Android", extractAndroidVersion(userAgent) + } + if strings.Contains(userAgent, "Ubuntu") { + return "Ubuntu", "" + } + return "Linux", "" +} + +func parseiOSVersion(userAgent string) (string, string) { + idx := strings.Index(userAgent, "OS ") + if idx == -1 { + return "iOS", "Unknown" + } + + versionStart := idx + 3 + versionEnd := strings.Index(userAgent[versionStart:], " ") + if versionEnd <= 0 { + return "iOS", "Unknown" + } + + ver := userAgent[versionStart : versionStart+versionEnd] + ver = strings.ReplaceAll(ver, "_", ".") + return "iOS", ver +} + +func extractAndroidVersion(userAgent string) string { + if idx := strings.Index(userAgent, "Android "); idx != -1 { + versionStart := idx + len("Android ") + versionEnd := strings.IndexAny(userAgent[versionStart:], ";)") + if versionEnd > 0 { + return userAgent[versionStart : versionStart+versionEnd] + } + } + return "Unknown" +} + +func detectEnvironmentFromUA(_ string) Environment { + return Environment{} +} diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go new file mode 100644 index 000000000..d542e2739 --- /dev/null +++ b/client/wasm/cmd/main.go @@ -0,0 +1,245 @@ +//go:build js + +package main + +import ( + "context" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + + netbird "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/client/wasm/internal/http" + "github.com/netbirdio/netbird/client/wasm/internal/rdp" + "github.com/netbirdio/netbird/client/wasm/internal/ssh" + "github.com/netbirdio/netbird/util" +) + +const ( + clientStartTimeout = 30 * time.Second + clientStopTimeout = 10 * time.Second + defaultLogLevel = "warn" +) + +func main() { + js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor)) + + select {} +} + +func startClient(ctx context.Context, nbClient *netbird.Client) error { + log.Info("Starting NetBird client...") + if err := nbClient.Start(ctx); err != nil { + return err + } + log.Info("NetBird client started successfully") + return nil +} + +// parseClientOptions extracts NetBird options from JavaScript object +func parseClientOptions(jsOptions js.Value) (netbird.Options, error) { + options := netbird.Options{ + DeviceName: "dashboard-client", + LogLevel: defaultLogLevel, + } + + if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() { + options.JWTToken = jwtToken.String() + } + + if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() { + options.SetupKey = setupKey.String() + } + + if privateKey := jsOptions.Get("privateKey"); !privateKey.IsNull() && !privateKey.IsUndefined() { + options.PrivateKey = privateKey.String() + } + + if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() { + mgmtURLStr := mgmtURL.String() + if mgmtURLStr != "" { + options.ManagementURL = mgmtURLStr + } + } + + if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() { + options.LogLevel = logLevel.String() + } + + if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() { + options.DeviceName = deviceName.String() + } + + return options, nil +} + +// createStartMethod creates the start method for the client +func createStartMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout) + defer cancel() + + if err := startClient(ctx, client); err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createStopMethod creates the stop method for the client +func createStopMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout) + defer cancel() + + if err := client.Stop(ctx); err != nil { + log.Errorf("Error stopping client: %v", err) + reject.Invoke(js.ValueOf(err.Error())) + return + } + + log.Info("NetBird client stopped") + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + +// createSSHMethod creates the SSH connection method +func createSSHMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: requires host and port") + } + + host := args[0].String() + port := args[1].Int() + username := "root" + if len(args) > 2 && args[2].String() != "" { + username = args[2].String() + } + + return createPromise(func(resolve, reject js.Value) { + sshClient := ssh.NewClient(client) + + if err := sshClient.Connect(host, port, username); err != nil { + reject.Invoke(err.Error()) + return + } + + if err := sshClient.StartSession(80, 24); err != nil { + if closeErr := sshClient.Close(); closeErr != nil { + log.Errorf("Error closing SSH client: %v", closeErr) + } + reject.Invoke(err.Error()) + return + } + + jsInterface := ssh.CreateJSInterface(sshClient) + resolve.Invoke(jsInterface) + }) + }) +} + +// createProxyRequestMethod creates the proxyRequest method +func createProxyRequestMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: request details required") + } + + request := args[0] + + return createPromise(func(resolve, reject js.Value) { + response, err := http.ProxyRequest(client, request) + if err != nil { + reject.Invoke(err.Error()) + return + } + resolve.Invoke(response) + }) + }) +} + +// createRDPProxyMethod creates the RDP proxy method +func createRDPProxyMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: hostname and port required") + } + + proxy := rdp.NewRDCleanPathProxy(client) + return proxy.CreateProxy(args[0].String(), args[1].String()) + }) +} + +// createPromise is a helper to create JavaScript promises +func createPromise(handler func(resolve, reject js.Value)) js.Value { + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + go handler(resolve, reject) + + return nil + })) +} + +// createClientObject wraps the NetBird client in a JavaScript object +func createClientObject(client *netbird.Client) js.Value { + obj := make(map[string]interface{}) + + obj["start"] = createStartMethod(client) + obj["stop"] = createStopMethod(client) + obj["createSSHConnection"] = createSSHMethod(client) + obj["proxyRequest"] = createProxyRequestMethod(client) + obj["createRDPProxy"] = createRDPProxyMethod(client) + + return js.ValueOf(obj) +} + +// netBirdClientConstructor acts as a JavaScript constructor function +func netBirdClientConstructor(this js.Value, args []js.Value) any { + return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any { + resolve := promiseArgs[0] + reject := promiseArgs[1] + + if len(args) < 1 { + reject.Invoke(js.ValueOf("Options object required")) + return nil + } + + go func() { + options, err := parseClientOptions(args[0]) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil { + log.Warnf("Failed to initialize logging: %v", err) + } + + log.Infof("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s", + options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL) + + client, err := netbird.New(options) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err))) + return + } + + clientObj := createClientObject(client) + log.Info("NetBird client created successfully") + resolve.Invoke(clientObj) + }() + + return nil + })) +} diff --git a/client/wasm/internal/http/http.go b/client/wasm/internal/http/http.go new file mode 100644 index 000000000..cddc9e681 --- /dev/null +++ b/client/wasm/internal/http/http.go @@ -0,0 +1,100 @@ +//go:build js + +package http + +import ( + "fmt" + "io" + log "github.com/sirupsen/logrus" + "net/http" + "strings" + "syscall/js" + "time" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + httpTimeout = 30 * time.Second + maxResponseSize = 1024 * 1024 // 1MB +) + +// performRequest executes an HTTP request through NetBird and returns the response and body +func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) { + httpClient := nbClient.NewHTTPClient() + httpClient.Timeout = httpTimeout + + req, err := http.NewRequest(method, url, strings.NewReader(string(body))) + if err != nil { + return nil, nil, fmt.Errorf("create request: %w", err) + } + + for key, value := range headers { + req.Header.Set(key, value) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("request failed: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + log.Errorf("failed to close response body: %v", err) + } + }() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) + if err != nil { + return nil, nil, fmt.Errorf("read response: %w", err) + } + + return resp, respBody, nil +} + +// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object +func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) { + url := request.Get("url").String() + if url == "" { + return js.Undefined(), fmt.Errorf("URL is required") + } + + method := "GET" + if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() { + method = strings.ToUpper(methodVal.String()) + } + + var requestBody []byte + if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() { + requestBody = []byte(bodyVal.String()) + } + + requestHeaders := make(map[string]string) + if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject { + headerKeys := js.Global().Get("Object").Call("keys", headersVal) + for i := 0; i < headerKeys.Length(); i++ { + key := headerKeys.Index(i).String() + value := headersVal.Get(key).String() + requestHeaders[key] = value + } + } + + resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody) + if err != nil { + return js.Undefined(), err + } + + result := js.Global().Get("Object").New() + result.Set("status", resp.StatusCode) + result.Set("statusText", resp.Status) + result.Set("body", string(body)) + + headers := js.Global().Get("Object").New() + for key, values := range resp.Header { + if len(values) > 0 { + headers.Set(strings.ToLower(key), values[0]) + } + } + result.Set("headers", headers) + + return result, nil +} diff --git a/client/wasm/internal/rdp/cert_validation.go b/client/wasm/internal/rdp/cert_validation.go new file mode 100644 index 000000000..4a23a4bc8 --- /dev/null +++ b/client/wasm/internal/rdp/cert_validation.go @@ -0,0 +1,96 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + certValidationTimeout = 60 * time.Second +) + +func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) { + if !conn.wsHandlers.Get("onCertificateRequest").Truthy() { + return false, fmt.Errorf("certificate validation handler not configured") + } + + certInfo := js.Global().Get("Object").New() + certInfo.Set("ServerAddr", conn.destination) + + certArray := js.Global().Get("Array").New() + for i, certBytes := range certChain { + uint8Array := js.Global().Get("Uint8Array").New(len(certBytes)) + js.CopyBytesToJS(uint8Array, certBytes) + certArray.SetIndex(i, uint8Array) + } + certInfo.Set("ServerCertChain", certArray) + if len(certChain) > 0 { + cert, err := x509.ParseCertificate(certChain[0]) + if err == nil { + info := js.Global().Get("Object").New() + info.Set("subject", cert.Subject.String()) + info.Set("issuer", cert.Issuer.String()) + info.Set("validFrom", cert.NotBefore.Format(time.RFC3339)) + info.Set("validTo", cert.NotAfter.Format(time.RFC3339)) + info.Set("serialNumber", cert.SerialNumber.String()) + certInfo.Set("CertificateInfo", info) + } + } + + promise := conn.wsHandlers.Call("onCertificateRequest", certInfo) + + resultChan := make(chan bool) + errorChan := make(chan error) + + promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + result := args[0].Bool() + resultChan <- result + return nil + })).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} { + errorChan <- fmt.Errorf("certificate validation failed") + return nil + })) + + select { + case result := <-resultChan: + if result { + log.Info("Certificate accepted by user") + } else { + log.Info("Certificate rejected by user") + } + return result, nil + case err := <-errorChan: + return false, err + case <-time.After(certValidationTimeout): + return false, fmt.Errorf("certificate validation timeout") + } +} + +func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, // We'll validate manually after handshake + VerifyConnection: func(cs tls.ConnectionState) error { + var certChain [][]byte + for _, cert := range cs.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + + accepted, err := p.validateCertificateWithJS(conn, certChain) + if err != nil { + return err + } + if !accepted { + return fmt.Errorf("certificate rejected by user") + } + + return nil + }, + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath.go b/client/wasm/internal/rdp/rdcleanpath.go new file mode 100644 index 000000000..8062a05cc --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath.go @@ -0,0 +1,271 @@ +//go:build js + +package rdp + +import ( + "context" + "crypto/tls" + "encoding/asn1" + "fmt" + "io" + "net" + "sync" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +const ( + RDCleanPathVersion = 3390 + RDCleanPathProxyHost = "rdcleanpath.proxy.local" + RDCleanPathProxyScheme = "ws" +) + +type RDCleanPathPDU struct { + Version int64 `asn1:"tag:0,explicit"` + Error []byte `asn1:"tag:1,explicit,optional"` + Destination string `asn1:"utf8,tag:2,explicit,optional"` + ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"` + ServerAuth string `asn1:"utf8,tag:4,explicit,optional"` + PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"` + X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"` + ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"` + ServerAddr string `asn1:"utf8,tag:9,explicit,optional"` +} + +type RDCleanPathProxy struct { + nbClient interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) + } + activeConnections map[string]*proxyConnection + destinations map[string]string + mu sync.Mutex +} + +type proxyConnection struct { + id string + destination string + rdpConn net.Conn + tlsConn *tls.Conn + wsHandlers js.Value + ctx context.Context + cancel context.CancelFunc +} + +// NewRDCleanPathProxy creates a new RDCleanPath proxy +func NewRDCleanPathProxy(client interface { + Dial(ctx context.Context, network, address string) (net.Conn, error) +}) *RDCleanPathProxy { + return &RDCleanPathProxy{ + nbClient: client, + activeConnections: make(map[string]*proxyConnection), + } +} + +// CreateProxy creates a new proxy endpoint for the given destination +func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value { + destination := fmt.Sprintf("%s:%s", hostname, port) + + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any { + resolve := args[0] + + go func() { + proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections)) + + p.mu.Lock() + if p.destinations == nil { + p.destinations = make(map[string]string) + } + p.destinations[proxyID] = destination + p.mu.Unlock() + + proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID) + + // Register the WebSocket handler for this specific proxy + js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: requires WebSocket argument") + } + + ws := args[0] + p.HandleWebSocketConnection(ws, proxyID) + return nil + })) + + log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination) + resolve.Invoke(proxyURL) + }() + + return nil + })) +} + +// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP +func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) { + p.mu.Lock() + destination := p.destinations[proxyID] + p.mu.Unlock() + + if destination == "" { + log.Errorf("No destination found for proxy ID: %s", proxyID) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + // Don't defer cancel here - it will be called by cleanupConnection + + conn := &proxyConnection{ + id: proxyID, + destination: destination, + wsHandlers: ws, + ctx: ctx, + cancel: cancel, + } + + p.mu.Lock() + p.activeConnections[proxyID] = conn + p.mu.Unlock() + + p.setupWebSocketHandlers(ws, conn) + + log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID) +} + +func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) { + ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return nil + } + + data := args[0] + go p.handleWebSocketMessage(conn, data) + return nil + })) + + ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any { + log.Debug("WebSocket closed by JavaScript") + conn.cancel() + return nil + })) +} + +func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) { + if !data.InstanceOf(js.Global().Get("Uint8Array")) { + return + } + + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + + if conn.rdpConn != nil || conn.tlsConn != nil { + p.forwardToRDP(conn, bytes) + return + } + + var pdu RDCleanPathPDU + _, err := asn1.Unmarshal(bytes, &pdu) + if err != nil { + log.Warnf("Failed to parse RDCleanPath PDU: %v", err) + n := len(bytes) + if n > 20 { + n = 20 + } + log.Warnf("First %d bytes: %x", n, bytes[:n]) + + if len(bytes) > 0 && bytes[0] == 0x03 { + log.Debug("Received raw RDP packet instead of RDCleanPath PDU") + go p.handleDirectRDP(conn, bytes) + return + } + return + } + + go p.processRDCleanPathPDU(conn, pdu) +} + +func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) { + var writer io.Writer + var connType string + + if conn.tlsConn != nil { + writer = conn.tlsConn + connType = "TLS" + } else if conn.rdpConn != nil { + writer = conn.rdpConn + connType = "TCP" + } else { + log.Error("No RDP connection available") + return + } + + if _, err := writer.Write(bytes); err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + } +} + +func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) { + defer p.cleanupConnection(conn) + + destination := conn.destination + log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination) + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + return + } + conn.rdpConn = rdpConn + + _, err = rdpConn.Write(firstPacket) + if err != nil { + log.Errorf("Failed to write first packet: %v", err) + return + } + + response := make([]byte, 1024) + n, err := rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + return + } + + p.sendToWebSocket(conn, response[:n]) + + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") +} + +func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) { + log.Debugf("Cleaning up connection %s", conn.id) + conn.cancel() + if conn.tlsConn != nil { + log.Debug("Closing TLS connection") + if err := conn.tlsConn.Close(); err != nil { + log.Debugf("Error closing TLS connection: %v", err) + } + conn.tlsConn = nil + } + if conn.rdpConn != nil { + log.Debug("Closing TCP connection") + if err := conn.rdpConn.Close(); err != nil { + log.Debugf("Error closing TCP connection: %v", err) + } + conn.rdpConn = nil + } + p.mu.Lock() + delete(p.activeConnections, conn.id) + p.mu.Unlock() +} + +func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) { + if conn.wsHandlers.Get("receiveFromGo").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer")) + } else if conn.wsHandlers.Get("send").Truthy() { + uint8Array := js.Global().Get("Uint8Array").New(len(data)) + js.CopyBytesToJS(uint8Array, data) + conn.wsHandlers.Call("send", uint8Array.Get("buffer")) + } +} diff --git a/client/wasm/internal/rdp/rdcleanpath_handlers.go b/client/wasm/internal/rdp/rdcleanpath_handlers.go new file mode 100644 index 000000000..010efa5ea --- /dev/null +++ b/client/wasm/internal/rdp/rdcleanpath_handlers.go @@ -0,0 +1,251 @@ +//go:build js + +package rdp + +import ( + "crypto/tls" + "encoding/asn1" + "io" + "syscall/js" + + log "github.com/sirupsen/logrus" +) + +func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination) + + if pdu.Version != RDCleanPathVersion { + p.sendRDCleanPathError(conn, "Unsupported version") + return + } + + destination := conn.destination + if pdu.Destination != "" { + destination = pdu.Destination + } + + rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination) + if err != nil { + log.Errorf("Failed to connect to %s: %v", destination, err) + p.sendRDCleanPathError(conn, "Connection failed") + p.cleanupConnection(conn) + return + } + conn.rdpConn = rdpConn + + // RDP always starts with X.224 negotiation, then determines if TLS is needed + // Modern RDP (since Windows Vista/2008) typically requires TLS + // The X.224 Connection Confirm response will indicate if TLS is required + // For now, we'll attempt TLS for all connections as it's the modern default + p.setupTLSConnection(conn, pdu) +} + +func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + var x224Response []byte + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + x224Response = response[:n] + log.Debugf("Received X.224 Connection Confirm (%d bytes)", n) + } + + tlsConfig := p.getTLSConfigWithValidation(conn) + + tlsConn := tls.Client(conn.rdpConn, tlsConfig) + conn.tlsConn = tlsConn + + if err := tlsConn.Handshake(); err != nil { + log.Errorf("TLS handshake failed: %v", err) + p.sendRDCleanPathError(conn, "TLS handshake failed") + return + } + + log.Info("TLS handshake successful") + + // Certificate validation happens during handshake via VerifyConnection callback + var certChain [][]byte + connState := tlsConn.ConnectionState() + if len(connState.PeerCertificates) > 0 { + for _, cert := range connState.PeerCertificates { + certChain = append(certChain, cert.Raw) + } + log.Debugf("Extracted %d certificates from TLS connection", len(certChain)) + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + ServerCertChain: certChain, + } + + if len(x224Response) > 0 { + responsePDU.X224ConnectionPDU = x224Response + } + + p.sendRDCleanPathPDU(conn, responsePDU) + + log.Debug("Starting TLS forwarding") + go p.forwardConnToWS(conn, conn.tlsConn, "TLS") + go p.forwardWSToConn(conn, conn.tlsConn, "TLS") + + <-conn.ctx.Done() + log.Debug("TLS connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) { + if len(pdu.X224ConnectionPDU) > 0 { + log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU)) + _, err := conn.rdpConn.Write(pdu.X224ConnectionPDU) + if err != nil { + log.Errorf("Failed to write X.224 PDU: %v", err) + p.sendRDCleanPathError(conn, "Failed to forward X.224") + return + } + + response := make([]byte, 1024) + n, err := conn.rdpConn.Read(response) + if err != nil { + log.Errorf("Failed to read X.224 response: %v", err) + p.sendRDCleanPathError(conn, "Failed to read X.224 response") + return + } + + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + X224ConnectionPDU: response[:n], + ServerAddr: conn.destination, + } + + p.sendRDCleanPathPDU(conn, responsePDU) + } else { + responsePDU := RDCleanPathPDU{ + Version: RDCleanPathVersion, + ServerAddr: conn.destination, + } + p.sendRDCleanPathPDU(conn, responsePDU) + } + + go p.forwardConnToWS(conn, conn.rdpConn, "TCP") + go p.forwardWSToConn(conn, conn.rdpConn, "TCP") + + <-conn.ctx.Done() + log.Debug("TCP connection context done, cleaning up") + p.cleanupConnection(conn) +} + +func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) { + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal RDCleanPath PDU: %v", err) + return + } + + log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data)) + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) { + pdu := RDCleanPathPDU{ + Version: RDCleanPathVersion, + Error: []byte(errorMsg), + } + + data, err := asn1.Marshal(pdu) + if err != nil { + log.Errorf("Failed to marshal error PDU: %v", err) + return + } + + p.sendToWebSocket(conn, data) +} + +func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) { + msgChan := make(chan []byte) + errChan := make(chan error) + + handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + if len(args) < 1 { + errChan <- io.EOF + return nil + } + + data := args[0] + if data.InstanceOf(js.Global().Get("Uint8Array")) { + length := data.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, data) + msgChan <- bytes + } + return nil + }) + defer handler.Release() + + conn.wsHandlers.Set("onceGoMessage", handler) + + select { + case msg := <-msgChan: + return msg, nil + case err := <-errChan: + return nil, err + case <-conn.ctx.Done(): + return nil, conn.ctx.Err() + } +} + +func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) { + for { + if conn.ctx.Err() != nil { + return + } + + msg, err := p.readWebSocketMessage(conn) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from WebSocket: %v", err) + } + return + } + + _, err = dst.Write(msg) + if err != nil { + log.Errorf("Failed to write to %s: %v", connType, err) + return + } + } +} + +func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) { + buffer := make([]byte, 32*1024) + + for { + if conn.ctx.Err() != nil { + return + } + + n, err := src.Read(buffer) + if err != nil { + if err != io.EOF { + log.Errorf("Failed to read from %s: %v", connType, err) + } + return + } + + if n > 0 { + p.sendToWebSocket(conn, buffer[:n]) + } + } +} diff --git a/client/wasm/internal/ssh/client.go b/client/wasm/internal/ssh/client.go new file mode 100644 index 000000000..ca35525eb --- /dev/null +++ b/client/wasm/internal/ssh/client.go @@ -0,0 +1,213 @@ +//go:build js + +package ssh + +import ( + "context" + "fmt" + "io" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" + + netbird "github.com/netbirdio/netbird/client/embed" +) + +const ( + sshDialTimeout = 30 * time.Second +) + +func closeWithLog(c io.Closer, resource string) { + if c != nil { + if err := c.Close(); err != nil { + logrus.Debugf("Failed to close %s: %v", resource, err) + } + } +} + +type Client struct { + nbClient *netbird.Client + sshClient *ssh.Client + session *ssh.Session + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + mu sync.RWMutex +} + +// NewClient creates a new SSH client +func NewClient(nbClient *netbird.Client) *Client { + return &Client{ + nbClient: nbClient, + } +} + +// Connect establishes an SSH connection through NetBird network +func (c *Client) Connect(host string, port int, username string) error { + addr := fmt.Sprintf("%s:%d", host, port) + logrus.Infof("SSH: Connecting to %s as %s", addr, username) + + var authMethods []ssh.AuthMethod + + nbConfig, err := c.nbClient.GetConfig() + if err != nil { + return fmt.Errorf("get NetBird config: %w", err) + } + if nbConfig.SSHKey == "" { + return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization") + } + + signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey)) + if err != nil { + return fmt.Errorf("parse NetBird SSH private key: %w", err) + } + + pubKey := signer.PublicKey() + logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type()) + + authMethods = append(authMethods, ssh.PublicKeys(signer)) + + config := &ssh.ClientConfig{ + User: username, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: sshDialTimeout, + } + + ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout) + defer cancel() + + conn, err := c.nbClient.Dial(ctx, "tcp", addr) + if err != nil { + return fmt.Errorf("dial %s: %w", addr, err) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + closeWithLog(conn, "connection after handshake error") + return fmt.Errorf("SSH handshake: %w", err) + } + + c.sshClient = ssh.NewClient(sshConn, chans, reqs) + logrus.Infof("SSH: Connected to %s", addr) + + return nil +} + +// StartSession starts an SSH session with PTY +func (c *Client) StartSession(cols, rows int) error { + if c.sshClient == nil { + return fmt.Errorf("SSH client not connected") + } + + session, err := c.sshClient.NewSession() + if err != nil { + return fmt.Errorf("create session: %w", err) + } + + c.mu.Lock() + defer c.mu.Unlock() + c.session = session + + modes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + ssh.VINTR: 3, + ssh.VQUIT: 28, + ssh.VERASE: 127, + } + + if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { + closeWithLog(session, "session after PTY error") + return fmt.Errorf("PTY request: %w", err) + } + + c.stdin, err = session.StdinPipe() + if err != nil { + closeWithLog(session, "session after stdin error") + return fmt.Errorf("get stdin: %w", err) + } + + c.stdout, err = session.StdoutPipe() + if err != nil { + closeWithLog(session, "session after stdout error") + return fmt.Errorf("get stdout: %w", err) + } + + c.stderr, err = session.StderrPipe() + if err != nil { + closeWithLog(session, "session after stderr error") + return fmt.Errorf("get stderr: %w", err) + } + + if err := session.Shell(); err != nil { + closeWithLog(session, "session after shell error") + return fmt.Errorf("start shell: %w", err) + } + + logrus.Info("SSH: Session started with PTY") + return nil +} + +// Write sends data to the SSH session +func (c *Client) Write(data []byte) (int, error) { + c.mu.RLock() + stdin := c.stdin + c.mu.RUnlock() + + if stdin == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdin.Write(data) +} + +// Read reads data from the SSH session +func (c *Client) Read(buffer []byte) (int, error) { + c.mu.RLock() + stdout := c.stdout + c.mu.RUnlock() + + if stdout == nil { + return 0, fmt.Errorf("SSH session not started") + } + return stdout.Read(buffer) +} + +// Resize updates the terminal size +func (c *Client) Resize(cols, rows int) error { + c.mu.RLock() + session := c.session + c.mu.RUnlock() + + if session == nil { + return fmt.Errorf("SSH session not started") + } + return session.WindowChange(rows, cols) +} + +// Close closes the SSH connection +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.session != nil { + closeWithLog(c.session, "SSH session") + c.session = nil + } + if c.stdin != nil { + closeWithLog(c.stdin, "stdin") + c.stdin = nil + } + c.stdout = nil + c.stderr = nil + + if c.sshClient != nil { + err := c.sshClient.Close() + c.sshClient = nil + return err + } + return nil +} diff --git a/client/wasm/internal/ssh/handlers.go b/client/wasm/internal/ssh/handlers.go new file mode 100644 index 000000000..ea64eb0aa --- /dev/null +++ b/client/wasm/internal/ssh/handlers.go @@ -0,0 +1,78 @@ +//go:build js + +package ssh + +import ( + "io" + "syscall/js" + + "github.com/sirupsen/logrus" +) + +// CreateJSInterface creates a JavaScript interface for the SSH client +func CreateJSInterface(client *Client) js.Value { + jsInterface := js.Global().Get("Object").Call("create", js.Null()) + + jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf(false) + } + + data := args[0] + var bytes []byte + + if data.Type() == js.TypeString { + bytes = []byte(data.String()) + } else { + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes = make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + } + + _, err := client.Write(bytes) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf(false) + } + cols := args[0].Int() + rows := args[1].Int() + err := client.Resize(cols, rows) + return js.ValueOf(err == nil) + })) + + jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any { + client.Close() + return js.Undefined() + })) + + go readLoop(client, jsInterface) + + return jsInterface +} + +func readLoop(client *Client, jsInterface js.Value) { + buffer := make([]byte, 4096) + for { + n, err := client.Read(buffer) + if err != nil { + if err != io.EOF { + logrus.Debugf("SSH read error: %v", err) + } + if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() { + onclose.Invoke() + } + client.Close() + return + } + + if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() { + uint8Array := js.Global().Get("Uint8Array").New(n) + js.CopyBytesToJS(uint8Array, buffer[:n]) + ondata.Invoke(uint8Array) + } + } +} diff --git a/client/wasm/internal/ssh/key.go b/client/wasm/internal/ssh/key.go new file mode 100644 index 000000000..4868ba30a --- /dev/null +++ b/client/wasm/internal/ssh/key.go @@ -0,0 +1,50 @@ +//go:build js + +package ssh + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + "strings" + + "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" +) + +// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format +func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) { + keyStr := string(keyPEM) + if !strings.Contains(keyStr, "-----BEGIN") { + keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----") + } + + signer, err := ssh.ParsePrivateKey(keyPEM) + if err == nil { + return signer, nil + } + logrus.Debugf("SSH: Failed to parse as SSH format: %v", err) + + block, _ := pem.Decode(keyPEM) + if block == nil { + keyPreview := string(keyPEM) + if len(keyPreview) > 100 { + keyPreview = keyPreview[:100] + } + return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview) + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err) + if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(rsaKey) + } + if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return ssh.NewSignerFromKey(ecKey) + } + return nil, fmt.Errorf("parse private key: %w", err) + } + + return ssh.NewSignerFromKey(key) +} diff --git a/encryption/route53.go b/encryption/route53.go index 3c81ab103..48c7a3a1b 100644 --- a/encryption/route53.go +++ b/encryption/route53.go @@ -1,3 +1,5 @@ +//go:build !js + package encryption import ( diff --git a/flow/client/client.go b/flow/client/client.go index 603fd6882..03a4accaf 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -38,7 +38,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl return nil, fmt.Errorf("parsing url: %w", err) } var opts []grpc.DialOption - if parsedURL.Scheme == "https" { + tlsEnabled := parsedURL.Scheme == "https" + if tlsEnabled { certPool, err := x509.SystemCertPool() if err != nil || certPool == nil { log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) @@ -53,7 +54,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(), + nbgrpc.WithCustomDialer(tlsEnabled), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/go.mod b/go.mod index 23aa45277..c4b629993 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/c-robinson/iplib v1.0.3 github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 - github.com/coder/websocket v1.8.12 + github.com/coder/websocket v1.8.13 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 github.com/eko/gocache/lib/v4 v4.2.0 diff --git a/go.sum b/go.sum index 7096be3fe..13838b82d 100644 --- a/go.sum +++ b/go.sum @@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= -github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 984a56a39..ddd81daa2 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -10,6 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { @@ -56,8 +58,8 @@ func (s *BaseServer) AuthManager() auth.Manager { }) } -func (s *BaseServer) EphemeralManager() *server.EphemeralManager { - return Create(s, func() *server.EphemeralManager { - return server.NewEphemeralManager(s.Store(), s.AccountManager()) +func (s *BaseServer) EphemeralManager() ephemeral.Manager { + return Create(s, func() ephemeral.Manager { + return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 70f0f93a9..daec4ef6f 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -65,6 +65,10 @@ func (s *BaseServer) AccountManager() account.Manager { if err != nil { log.Fatalf("failed to create account manager: %v", err) } + + s.AfterInit(func(s *BaseServer) { + accountManager.SetEphemeralManager(s.EphemeralManager()) + }) return accountManager }) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index e868c2529..ae9ac4a60 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -6,12 +6,14 @@ import ( "fmt" "net" "net/http" + "net/netip" "strings" "sync" "time" "github.com/google/uuid" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -22,6 +24,8 @@ import ( "github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" ) @@ -92,12 +96,6 @@ func (s *BaseServer) Start(ctx context.Context) error { s.PeersManager() s.GeoLocationManager() - for _, fn := range s.afterInit { - if fn != nil { - fn(s) - } - } - err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics") if err != nil { return fmt.Errorf("failed to expose metrics: %v", err) @@ -147,7 +145,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String()) } - rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler()) + rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter()) switch { case s.certManager != nil: // a call to certManager.Listener() always creates a new listener so we do it once @@ -176,6 +174,12 @@ func (s *BaseServer) Start(ctx context.Context) error { } } + for _, fn := range s.afterInit { + if fn != nil { + fn(s) + } + } + log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) @@ -247,13 +251,17 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config) return util.DirectWriteJson(ctx, path, config) } -func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler { +func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto") - if request.ProtoMajor == 2 && grpcHeader { + switch { + case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || + strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - } else { + case request.URL.Path == wsproxy.ProxyPath: + wsProxy.Handler().ServeHTTP(writer, request) + default: httpHandler.ServeHTTP(writer, request) } }) diff --git a/management/server/account.go b/management/server/account.go index ee9f294a4..dca105ddf 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -74,6 +75,7 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation + ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -261,6 +263,10 @@ func BuildManager( return am, nil } +func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { + am.ephemeralManager = em +} + func (am *DefaultAccountManager) startWarmup(ctx context.Context) { var initialInterval int64 intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 30fbbbc3e..a1ed9498b 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -12,6 +12,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -56,7 +57,7 @@ type Manager interface { UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) @@ -125,5 +126,6 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool } diff --git a/management/server/account_test.go b/management/server/account_test.go index 81a921bf9..07d2f2383 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -66,7 +66,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account setupKey = key.Key } - _, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false) if err != nil { t.Error("expected to add new peer successfully after creating new account, but failed", err) } @@ -1048,10 +1048,10 @@ func TestAccountManager_AddPeer(t *testing.T) { } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1112,10 +1112,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { expectedPeerKey := key.PublicKey().String() expectedUserID := userID - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) return @@ -1429,10 +1429,10 @@ func TestAccountManager_DeletePeer(t *testing.T) { peerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1805,11 +1805,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1861,11 +1861,11 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, @@ -1904,11 +1904,11 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test key, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - _, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, LoginExpirationEnabled: true, - }) + }, false) require.NoError(t, err, "unable to add peer") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2952,14 +2952,14 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, } expectedPeerKey := key.PublicKey().String() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, Status: &nbpeer.PeerStatus{ Connected: true, LastSeen: time.Now().UTC(), }, - }) + }, false) if err != nil { t.Fatalf("expecting peer to be added, got failure %v", err) } @@ -3552,16 +3552,16 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { key2, err := wgtypes.GenerateKey() require.NoError(t, err, "unable to generate WireGuard key") - peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) require.NoError(t, err, "unable to add peer1") - peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ Key: key2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) require.NoError(t, err, "unable to add peer2") t.Run("update peer IP successfully", func(t *testing.T) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 55a1bbe66..a2a2ce529 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -281,11 +281,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) + savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 60a00207e..1177eefff 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -22,6 +22,7 @@ import ( integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,7 +56,7 @@ type GRPCServer struct { config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager + ephemeralManager ephemeral.Manager peerLocks sync.Map authManager auth.Manager @@ -73,7 +74,7 @@ func NewServer( peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager *EphemeralManager, + ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, ) (*GRPCServer, error) { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index af501e151..4b33495de 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -32,6 +32,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") } // NewHandler creates a new peers Handler @@ -318,6 +319,88 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } +func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + var req api.PeerTemporaryAccessRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + newPeer := &nbpeer.Peer{} + newPeer.FromAPITemporaryAccessRequest(&req) + + targetPeer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + for _, rule := range req.Rules { + protocol, portRange, err := types.ParseRuleString(rule) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + policy := &types.Policy{ + AccountID: userAuth.AccountId, + Description: "Temporary access policy for peer " + peer.Name, + Name: "Temporary access policy for peer " + peer.Name, + Enabled: true, + Rules: []*types.PolicyRule{{ + Name: "Temporary access rule", + Description: "Temporary access rule", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + SourceResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: peer.ID, + }, + DestinationResource: types.Resource{ + Type: types.ResourceTypePeer, + ID: targetPeer.ID, + }, + Bidirectional: false, + Protocol: protocol, + PortRanges: []types.RulePortRange{portRange}, + }}, + } + + _, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + } + + resp := &api.PeerTemporaryAccessResponse{ + Id: peer.ID, + Name: peer.Name, + Rules: req.Rules, + } + + util.WriteJSONObject(r.Context(), w, resp) +} + func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index ba4997d22..a34d2086b 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -26,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -460,7 +461,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - ephemeralMgr := NewEphemeralManager(store, accountManager) + ephemeralMgr := manager.NewEphemeralManager(store, accountManager) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) if err != nil { return nil, nil, "", cleanup, err diff --git a/management/server/management_test.go b/management/server/management_test.go index 61dc46d87..1a5e47354 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -228,7 +229,7 @@ func startServer( peersUpdateManager, secretsManager, nil, - nil, + &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, ) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 003385eb5..d160e7269 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,6 +15,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -41,7 +42,7 @@ type MockAccountManager struct { DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) @@ -351,12 +352,14 @@ func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( ctx context.Context, + accountID string, setupKey string, userId string, peer *nbpeer.Peer, + temporary bool, ) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { - return am.AddPeerFunc(ctx, setupKey, userId, peer) + return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } @@ -972,6 +975,11 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } +// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface +func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { + // Mock implementation - does nothing +} + func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 959e7856a..6c985410c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -876,11 +876,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false) if err != nil { return nil, err } - _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) + _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false) if err != nil { return nil, err } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 294f51676..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -132,7 +132,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc res := nbtypes.Resource{ ID: resource.ID, - Type: resource.Type.String(), + Type: nbtypes.ResourceType(resource.Type.String()), } for _, groupID := range resource.GroupIDs { event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) @@ -265,7 +265,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { res := nbtypes.Resource{ ID: newResource.ID, - Type: newResource.Type.String(), + Type: nbtypes.ResourceType(newResource.Type.String()), } oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) diff --git a/management/server/peer.go b/management/server/peer.go index 81f037499..ea4617af0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -450,7 +450,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -482,8 +482,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var ephemeral bool var groupsToAdd []string var allowExtraDNSLabels bool - var accountID string - var isEphemeral bool if addedByUser { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -492,10 +490,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if user.PendingApproval { return nil, nil, nil, status.Errorf(status.PermissionDenied, "user pending approval cannot add peers") } - groupsToAdd = user.AutoGroups + if temporary { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Create) + if err != nil { + return nil, nil, nil, status.NewPermissionValidationError(err) + } + + if !allowed { + return nil, nil, nil, status.NewPermissionDeniedError() + } + } else { + accountID = user.AccountID + groupsToAdd = user.AutoGroups + } opEvent.InitiatorID = userID opEvent.Activity = activity.PeerAddedByUser - accountID = user.AccountID } else { // Validate the setup key sk, err := am.Store.GetSetupKeyBySecret(ctx, store.LockingStrengthNone, encodedHashedKey) @@ -516,13 +525,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name allowExtraDNSLabels = sk.AllowExtraDNSLabels accountID = sk.AccountID - isEphemeral = sk.Ephemeral if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } } opEvent.AccountID = accountID + if temporary { + ephemeral = true + } + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if am.idpManager != nil { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) @@ -549,10 +561,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s SSHKey: peer.SSHKey, LastLogin: ®istrationTime, CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, + LoginExpirationEnabled: addedByUser && !temporary, Ephemeral: ephemeral, Location: peer.Location, - InactivityExpirationEnabled: addedByUser, + InactivityExpirationEnabled: addedByUser && !temporary, ExtraDNSLabels: peer.ExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels, } @@ -588,7 +600,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var freeLabel string - if isEphemeral || attempt > 1 { + if ephemeral || attempt > 1 { freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname) if err != nil { return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) @@ -622,6 +634,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed adding peer to All group: %w", err) } + if temporary { + // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually + am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) + } + if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -790,7 +807,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo ExtraDNSLabels: login.ExtraDNSLabels, } - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + return am.AddPeer(ctx, "", login.SetupKey, login.UserID, newPeer, false) } log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) @@ -877,6 +894,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer if peer.SSHKey != login.SSHKey { peer.SSHKey = login.SSHKey shouldStorePeer = true + updateRemotePeers = true } if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 { @@ -1540,6 +1558,26 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peer.ID) + if err != nil { + return nil, err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return nil, err + } + + peerDeletedEvents = append(peerDeletedEvents, func() { + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 6a6d1c91d..f89f10dac 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -8,6 +8,7 @@ import ( "time" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/http/api" ) // Peer represents a machine connected to the network. @@ -334,6 +335,17 @@ func (p *Peer) UpdateLastLogin() *Peer { return p } +func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) { + p.Ephemeral = true + p.Name = a.Name + p.Key = a.WgPubKey + p.Meta = PeerSystemMeta{ + Hostname: a.Name, + GoOS: "js", + OS: "js", + } +} + func (f Flags) isEqual(other Flags) bool { return f.RosenpassEnabled == other.RosenpassEnabled && f.RosenpassPermissive == other.RosenpassPermissive && diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 31c309430..734536d7b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -193,10 +193,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -207,10 +207,10 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -266,10 +266,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -280,10 +280,10 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) return } - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -442,10 +442,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -456,10 +456,10 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) return } - _, _, _, err = manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) @@ -514,10 +514,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { return } - peer1, _, _, err := manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + peer1, _, _, err := manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -530,10 +530,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // the second peer added with a setup key - peer2, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + peer2, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Fatal(err) return @@ -702,19 +702,19 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - _, _, _, err = manager.AddPeer(context.Background(), "", someUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", someUser, &nbpeer.Peer{ Key: peerKey1.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return } - _, _, _, err = manager.AddPeer(context.Background(), "", adminUser, &nbpeer.Peer{ + _, _, _, err = manager.AddPeer(context.Background(), "", "", adminUser, &nbpeer.Peer{ Key: peerKey2.PublicKey().String(), Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, - }) + }, false) if err != nil { t.Errorf("expecting peer to be added, got failure %v", err) return @@ -1300,7 +1300,7 @@ func Test_RegisterPeerByUser(t *testing.T) { }, } - addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", "", existingUserID, newPeer, false) require.NoError(t, err) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) @@ -1422,7 +1422,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ExtraDNSLabels: newPeerTemplate.ExtraDNSLabels, } - addedPeer, _, _, err := am.AddPeer(context.Background(), tc.existingSetupKeyID, "", currentPeer) + addedPeer, _, _, err := am.AddPeer(context.Background(), "", tc.existingSetupKeyID, "", currentPeer, false) if tc.expectAddPeerError { require.Error(t, err, "Expected an error when adding peer with setup key: %s", tc.existingSetupKeyID) @@ -1523,7 +1523,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { SSHEnabled: false, } - _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + _, _, _, err = am.AddPeer(context.Background(), "", faultyKey, "", newPeer, false) require.Error(t, err) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key) @@ -1658,7 +1658,7 @@ func Test_LoginPeer(t *testing.T) { if sk.AllowExtraDNSLabels { currentPeer.ExtraDNSLabels = newPeerTemplate.ExtraDNSLabels } - _, _, _, err = am.AddPeer(context.Background(), tc.setupKey, "", currentPeer) + _, _, _, err = am.AddPeer(context.Background(), "", tc.setupKey, "", currentPeer, false) require.NoError(t, err, "Expected no error when adding peer with setup key: %s", tc.setupKey) loginInput := types.PeerLogin{ @@ -1797,10 +1797,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1918,11 +1918,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + peer4, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser1", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -1982,11 +1982,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer5, _, _, err = manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer5, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2037,11 +2037,11 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer6, _, _, err = manager.AddPeer(context.Background(), "", "regularUser3", &nbpeer.Peer{ + peer6, _, _, err = manager.AddPeer(context.Background(), "", "", "regularUser3", &nbpeer.Peer{ Key: expectedPeerKey, LoginExpirationEnabled: true, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) select { @@ -2208,7 +2208,7 @@ func Test_AddPeer(t *testing.T) { <-start - _, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", newPeer) + _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", newPeer, false) if err != nil { errs <- fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return @@ -2416,7 +2416,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", pendingUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", pendingUser.Id, peer, false) require.Error(t, err) assert.Contains(t, err.Error(), "user pending approval cannot add peers") } @@ -2451,7 +2451,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { }, } - _, _, _, err = manager.AddPeer(context.Background(), "", regularUser.Id, peer) + _, _, _, err = manager.AddPeer(context.Background(), "", "", regularUser.Id, peer, false) require.NoError(t, err, "Regular user should be able to add peers") } @@ -2494,7 +2494,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", pendingUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", pendingUser.Id, newPeer, false) require.NoError(t, err) // Now set the user back to pending approval after peer was created @@ -2550,7 +2550,7 @@ func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { WtVersion: "0.28.0", }, } - existingPeer, _, _, err := manager.AddPeer(context.Background(), "", regularUser.Id, newPeer) + existingPeer, _, _, err := manager.AddPeer(context.Background(), "", "", regularUser.Id, newPeer, false) require.NoError(t, err) // Try to login with regular user diff --git a/management/server/peers/ephemeral/interface.go b/management/server/peers/ephemeral/interface.go new file mode 100644 index 000000000..a1605b3b9 --- /dev/null +++ b/management/server/peers/ephemeral/interface.go @@ -0,0 +1,14 @@ +package ephemeral + +import ( + "context" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +type Manager interface { + LoadInitialPeers(ctx context.Context) + Stop() + OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) + OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) +} diff --git a/management/server/ephemeral.go b/management/server/peers/ephemeral/manager/ephemeral.go similarity index 99% rename from management/server/ephemeral.go rename to management/server/peers/ephemeral/manager/ephemeral.go index e3cb5459a..062ba69d2 100644 --- a/management/server/ephemeral.go +++ b/management/server/peers/ephemeral/manager/ephemeral.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" diff --git a/management/server/ephemeral_test.go b/management/server/peers/ephemeral/manager/ephemeral_test.go similarity index 75% rename from management/server/ephemeral_test.go rename to management/server/peers/ephemeral/manager/ephemeral_test.go index d07b9a422..fc7525c29 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/peers/ephemeral/manager/ephemeral_test.go @@ -1,4 +1,4 @@ -package server +package manager import ( "context" @@ -7,12 +7,15 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + nbdns "github.com/netbirdio/netbird/dns" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" ) type MockStore struct { @@ -223,3 +226,57 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) store.account.Peers[p.ID] = p } } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[route.ID]*route.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, + }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, + } + + if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} diff --git a/management/server/policy.go b/management/server/policy.go index 3adee6397..9e4b3f73a 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -151,6 +151,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return false, nil } + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err @@ -161,6 +167,12 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a } } + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 027938320..382d026c8 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2037,6 +2037,25 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) }) } +func (s *SqlStore) GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, resourceID string) ([]*types.PolicyRule, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var policyRules []*types.PolicyRule + resourceIDPattern := `%"ID":"` + resourceID + `"%` + result := tx.Where("source_resource LIKE ? OR destination_resource LIKE ?", resourceIDPattern, resourceIDPattern). + Find(&policyRules) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get policy rules for resource id from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policy rules for resource id from store") + } + + return policyRules, nil +} + // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { tx := s.db diff --git a/management/server/store/store.go b/management/server/store/store.go index 3c9d896b0..21b660d96 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -202,6 +202,7 @@ type Store interface { IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error + GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index a69d3bb08..f830023c7 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1001,8 +1001,20 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P continue } - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = a.getPeerFromResource(rule.SourceResource, peer.ID) + } else { + sourcePeers, peerInSources = a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = a.getPeerFromResource(rule.DestinationResource, peer.ID) + } else { + destinationPeers, peerInDestinations = a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap) + } if rule.Bidirectional { if peerInSources { @@ -1124,6 +1136,15 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe return filteredPeers, peerInGroups } +func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + peer := a.GetPeer(resource.ID) + if peer == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peer}, resource.ID == peerID +} + // validatePostureChecksOnPeer validates the posture checks on a peer func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { peer, ok := a.Peers[peerID] @@ -1379,7 +1400,12 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } if addSourcePeers { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { allSourcePeers[pID] = struct{}{} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 17964ed1f..5e86a87c6 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,12 @@ package types +import ( + "errors" + "fmt" + "strconv" + "strings" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -134,3 +141,83 @@ func (p *Policy) SourceGroups() []string { return groupIDs } + +func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error) { + rule = strings.TrimSpace(strings.ToLower(rule)) + if rule == "all" { + return PolicyRuleProtocolALL, RulePortRange{}, nil + } + if rule == "icmp" { + return PolicyRuleProtocolICMP, RulePortRange{}, nil + } + + split := strings.Split(rule, "/") + if len(split) != 2 { + return "", RulePortRange{}, errors.New("invalid rule format: expected protocol/port or protocol/port-range") + } + + protoStr := strings.TrimSpace(split[0]) + portStr := strings.TrimSpace(split[1]) + + var protocol PolicyRuleProtocolType + switch protoStr { + case "tcp": + protocol = PolicyRuleProtocolTCP + case "udp": + protocol = PolicyRuleProtocolUDP + case "icmp": + return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'") + default: + return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr) + } + + portRange, err := parsePortRange(portStr) + if err != nil { + return "", RulePortRange{}, err + } + + return protocol, portRange, nil +} + +func parsePortRange(portStr string) (RulePortRange, error) { + if strings.Contains(portStr, "-") { + rangeParts := strings.Split(portStr, "-") + if len(rangeParts) != 2 { + return RulePortRange{}, fmt.Errorf("invalid port range %q", portStr) + } + start, err := parsePort(strings.TrimSpace(rangeParts[0])) + if err != nil { + return RulePortRange{}, err + } + end, err := parsePort(strings.TrimSpace(rangeParts[1])) + if err != nil { + return RulePortRange{}, err + } + if start > end { + return RulePortRange{}, fmt.Errorf("invalid port range: start %d > end %d", start, end) + } + return RulePortRange{Start: uint16(start), End: uint16(end)}, nil + } + + p, err := parsePort(portStr) + if err != nil { + return RulePortRange{}, err + } + + return RulePortRange{Start: uint16(p), End: uint16(p)}, nil +} + +func parsePort(portStr string) (int, error) { + + if portStr == "" { + return 0, errors.New("empty port") + } + p, err := strconv.Atoi(portStr) + if err != nil { + return 0, fmt.Errorf("invalid port %q: %w", portStr, err) + } + if p < 1 || p > 65535 { + return 0, fmt.Errorf("port out of range (1–65535): %d", p) + } + return p, nil +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go index 84d8e4b88..8347d8c03 100644 --- a/management/server/types/resource.go +++ b/management/server/types/resource.go @@ -4,9 +4,18 @@ import ( "github.com/netbirdio/netbird/shared/management/http/api" ) +type ResourceType string + +const ( + ResourceTypePeer ResourceType = "peer" + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + type Resource struct { ID string - Type string + Type ResourceType } func (r *Resource) ToAPIResponse() *api.Resource { @@ -26,5 +35,5 @@ func (r *Resource) FromAPIRequest(req *api.Resource) { } r.ID = req.Id - r.Type = string(req.Type) + r.Type = ResourceType(req.Type) } diff --git a/management/server/user_test.go b/management/server/user_test.go index 9638559f9..5920a2a33 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1439,10 +1439,10 @@ func TestUserAccountPeersUpdate(t *testing.T) { require.NoError(t, err) expectedPeerKey := key.PublicKey().String() - peer4, _, _, err := manager.AddPeer(context.Background(), "", "regularUser2", &nbpeer.Peer{ + peer4, _, _, err := manager.AddPeer(context.Background(), "", "", "regularUser2", &nbpeer.Peer{ Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) + }, false) require.NoError(t, err) // updating user with linked peers should update account peers and send peer update diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index becc10ded..d4a9f1823 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/server/config" @@ -27,6 +28,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +119,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) if err != nil { t.Fatal(err) } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 9a531b2ff..93578b1ae 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -507,6 +507,48 @@ components: - serial_number - extra_dns_labels - ephemeral + PeerTemporaryAccessRequest: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + wg_pub_key: + description: Peer's WireGuard public key + type: string + example: "n0r3pL4c3h0ld3rK3y==" + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - wg_pub_key + - rules + PeerTemporaryAccessResponse: + type: object + properties: + name: + description: Peer's hostname + type: string + example: temp-host-1 + id: + description: Peer ID + type: string + example: chacbco6lnnbn6cg5s90 + rules: + description: List of temporary access rules + type: array + items: + type: string + example: "tcp/80" + required: + - name + - id + - rules AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -1404,7 +1446,8 @@ components: allOf: - $ref: '#/components/schemas/NetworkResourceType' - type: string - example: host + enum: ["peer"] + example: peer NetworkRequest: type: object properties: @@ -2793,6 +2836,42 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/peers/{peerId}/temporary-access: + post: + summary: Create a Temporary Access Peer + description: Creates a temporary access peer that can be used to access this peer and this peer only. The temporary access peer and its access policies will be automatically deleted after it disconnects. + tags: [ Peers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: peerId + required: true + schema: + type: string + description: The unique identifier of a peer + requestBody: + description: Temporary Access Peer create request + content: + 'application/json': + schema: + $ref: '#/components/schemas/PeerTemporaryAccessRequest' + responses: + '200': + description: Temporary Access Peer response + content: + application/json: + schema: + $ref: '#/components/schemas/PeerTemporaryAccessResponse' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers/{peerId}/ingress/ports: get: x-cloud-only: true diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 28b89633c..3dbb32ef6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -168,6 +168,7 @@ const ( const ( ResourceTypeDomain ResourceType = "domain" ResourceTypeHost ResourceType = "host" + ResourceTypePeer ResourceType = "peer" ResourceTypeSubnet ResourceType = "subnet" ) @@ -1221,6 +1222,30 @@ type PeerRequest struct { SshEnabled bool `json:"ssh_enabled"` } +// PeerTemporaryAccessRequest defines model for PeerTemporaryAccessRequest. +type PeerTemporaryAccessRequest struct { + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` + + // WgPubKey Peer's WireGuard public key + WgPubKey string `json:"wg_pub_key"` +} + +// PeerTemporaryAccessResponse defines model for PeerTemporaryAccessResponse. +type PeerTemporaryAccessResponse struct { + // Id Peer ID + Id string `json:"id"` + + // Name Peer's hostname + Name string `json:"name"` + + // Rules List of temporary access rules + Rules []string `json:"rules"` +} + // PersonalAccessToken defines model for PersonalAccessToken. type PersonalAccessToken struct { // CreatedAt Date the token was created @@ -1949,6 +1974,9 @@ type PostApiPeersPeerIdIngressPortsJSONRequestBody = IngressPortAllocationReques // PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody defines body for PutApiPeersPeerIdIngressPortsAllocationId for application/json ContentType. type PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody = IngressPortAllocationRequest +// PostApiPeersPeerIdTemporaryAccessJSONRequestBody defines body for PostApiPeersPeerIdTemporaryAccess for application/json ContentType. +type PostApiPeersPeerIdTemporaryAccessJSONRequestBody = PeerTemporaryAccessRequest + // PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType. type PostApiPoliciesJSONRequestBody = PolicyUpdate diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index bf614e8aa..8381d6682 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.24.3 +// protoc v6.32.0 // source: management.proto package proto diff --git a/shared/relay/client/client.go b/shared/relay/client/client.go index 5dabc5742..57a98614d 100644 --- a/shared/relay/client/client.go +++ b/shared/relay/client/client.go @@ -9,11 +9,8 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/iface" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" "github.com/netbirdio/netbird/shared/relay/client/dialer" - "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" - "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" "github.com/netbirdio/netbird/shared/relay/healthcheck" "github.com/netbirdio/netbird/shared/relay/messages" ) @@ -296,14 +293,7 @@ func (c *Client) Close() error { } func (c *Client) connect(ctx context.Context) (*RelayAddr, error) { - // Force WebSocket for MTUs larger than default to avoid QUIC DATAGRAM frame size issues - var dialers []dialer.DialeFn - if c.mtu > 0 && c.mtu > iface.DefaultMTU { - c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) - dialers = []dialer.DialeFn{ws.Dialer{}} - } else { - dialers = []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} - } + dialers := c.getDialers() rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...) conn, err := rd.Dial() diff --git a/shared/relay/client/dialer/ws/conn.go b/shared/relay/client/dialer/ws/conn.go index 0086b702b..d5b719f51 100644 --- a/shared/relay/client/dialer/ws/conn.go +++ b/shared/relay/client/dialer/ws/conn.go @@ -38,8 +38,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { } func (c *Conn) Write(b []byte) (n int, err error) { - err = c.Conn.Write(c.ctx, websocket.MessageBinary, b) - return 0, err + return 0, c.Conn.Write(c.ctx, websocket.MessageBinary, b) } func (c *Conn) RemoteAddr() net.Addr { diff --git a/shared/relay/client/dialer/ws/dialopts_generic.go b/shared/relay/client/dialer/ws/dialopts_generic.go new file mode 100644 index 000000000..9dfe698d0 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_generic.go @@ -0,0 +1,11 @@ +//go:build !js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + return &websocket.DialOptions{ + HTTPClient: httpClientNbDialer(), + } +} diff --git a/shared/relay/client/dialer/ws/dialopts_js.go b/shared/relay/client/dialer/ws/dialopts_js.go new file mode 100644 index 000000000..7eac27531 --- /dev/null +++ b/shared/relay/client/dialer/ws/dialopts_js.go @@ -0,0 +1,10 @@ +//go:build js + +package ws + +import "github.com/coder/websocket" + +func createDialOptions() *websocket.DialOptions { + // WASM version doesn't support HTTPClient + return &websocket.DialOptions{} +} diff --git a/shared/relay/client/dialer/ws/ws.go b/shared/relay/client/dialer/ws/ws.go index ef6bd6b3c..66fff3447 100644 --- a/shared/relay/client/dialer/ws/ws.go +++ b/shared/relay/client/dialer/ws/ws.go @@ -32,9 +32,7 @@ func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) { return nil, err } - opts := &websocket.DialOptions{ - HTTPClient: httpClientNbDialer(), - } + opts := createDialOptions() parsedURL, err := url.Parse(wsURL) if err != nil { diff --git a/shared/relay/client/dialers_generic.go b/shared/relay/client/dialers_generic.go new file mode 100644 index 000000000..a8ed79961 --- /dev/null +++ b/shared/relay/client/dialers_generic.go @@ -0,0 +1,19 @@ +//go:build !js + +package client + +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/quic" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +// getDialers returns the list of dialers to use for connecting to the relay server. +func (c *Client) getDialers() []dialer.DialeFn { + if c.mtu > 0 && c.mtu > iface.DefaultMTU { + c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU) + return []dialer.DialeFn{ws.Dialer{}} + } + return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}} +} diff --git a/shared/relay/client/dialers_js.go b/shared/relay/client/dialers_js.go new file mode 100644 index 000000000..6bd0e6696 --- /dev/null +++ b/shared/relay/client/dialers_js.go @@ -0,0 +1,13 @@ +//go:build js + +package client + +import ( + "github.com/netbirdio/netbird/shared/relay/client/dialer" + "github.com/netbirdio/netbird/shared/relay/client/dialer/ws" +) + +func (c *Client) getDialers() []dialer.DialeFn { + // JS/WASM build only uses WebSocket transport + return []dialer.DialeFn{ws.Dialer{}} +} diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 1d76fa4e4..e2a69a75b 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -8,14 +8,16 @@ import ( "fmt" "net" "net/http" - // nolint:gosec _ "net/http/pprof" - "strings" + "net/netip" "time" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" + "go.opentelemetry.io/otel/metric" "golang.org/x/crypto/acme/autocert" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "github.com/netbirdio/netbird/signal/metrics" @@ -23,6 +25,8 @@ import ( "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/signal/server" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/wsproxy" + wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server" "github.com/netbirdio/netbird/version" log "github.com/sirupsen/logrus" @@ -32,6 +36,8 @@ import ( "google.golang.org/grpc/keepalive" ) +const legacyGRPCPort = 10000 + var ( signalPort int metricsPort int @@ -113,7 +119,7 @@ var ( } proto.RegisterSignalExchangeServer(grpcServer, srv) - grpcRootHandler := grpcHandlerFunc(grpcServer) + grpcRootHandler := grpcHandlerFunc(grpcServer, metricsServer.Meter) if certManager != nil { startServerWithCertManager(certManager, grpcRootHandler) @@ -123,19 +129,30 @@ var ( var grpcListener net.Listener var httpListener net.Listener - // If certManager is configured and signalPort == 443, then the gRPC server has already been started - if certManager == nil || signalPort != 443 { - grpcListener, err = serveGRPC(grpcServer, signalPort) + // Start the main server - always serve HTTP with WebSocket proxy support + // If certManager is configured and signalPort == 443, it's already handled by startServerWithCertManager + if certManager == nil { + // Without TLS, serve plain HTTP + httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { return err } - log.Infof("running gRPC server: %s", grpcListener.Addr().String()) + log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) + } else if signalPort != 443 { + // With TLS but not on port 443, serve HTTPS + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + if err != nil { + return err + } + log.Infof("running HTTPS server with WebSocket proxy: %s", httpListener.Addr().String()) + serveHTTP(httpListener, grpcRootHandler) } - if signalPort != 10000 { + if signalPort != legacyGRPCPort { // The Signal gRPC server was running on port 10000 previously. Old agents that are already connected to Signal // are using port 10000. For compatibility purposes we keep running a 2nd gRPC server on port 10000. - compatListener, err = serveGRPC(grpcServer, 10000) + compatListener, err = serveGRPC(grpcServer, legacyGRPCPort) if err != nil { return err } @@ -236,11 +253,14 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h } } -func grpcHandlerFunc(grpcServer *grpc.Server) http.Handler { +func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { + wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter)) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - grpcHeader := strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") || - strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto") - if r.ProtoMajor == 2 && grpcHeader { + switch { + case r.URL.Path == wsproxy.ProxyPath: + wsProxy.Handler().ServeHTTP(w, r) + default: grpcServer.ServeHTTP(w, r) } }) @@ -257,7 +277,11 @@ func notifyStop(msg string) { func serveHTTP(httpListener net.Listener, handler http.Handler) { go func() { - err := http.Serve(httpListener, handler) + // Use h2c to support HTTP/2 without TLS (needed for gRPC) + h1s := &http.Server{ + Handler: h2c.NewHandler(handler, &http2.Server{}), + } + err := h1s.Serve(httpListener) if err != nil { notifyStop(fmt.Sprintf("failed running HTTP server %v", err)) } diff --git a/util/util_js.go b/util/util_js.go new file mode 100644 index 000000000..8c243cab3 --- /dev/null +++ b/util/util_js.go @@ -0,0 +1,8 @@ +//go:build js + +package util + +// IsAdmin returns false for WASM as there's no admin concept in browser +func IsAdmin() bool { + return false +} diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go new file mode 100644 index 000000000..2caeed025 --- /dev/null +++ b/util/wsproxy/client/dialer_js.go @@ -0,0 +1,171 @@ +package client + +import ( + "context" + "fmt" + "net" + "sync" + "syscall/js" + "time" + + log "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const dialTimeout = 30 * time.Second + +// websocketConn wraps a JavaScript WebSocket to implement net.Conn +type websocketConn struct { + ws js.Value + remoteAddr string + messages chan []byte + readBuf []byte + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex +} + +func (c *websocketConn) Read(b []byte) (int, error) { + c.mu.Lock() + if len(c.readBuf) > 0 { + n := copy(b, c.readBuf) + c.readBuf = c.readBuf[n:] + c.mu.Unlock() + return n, nil + } + c.mu.Unlock() + + select { + case data := <-c.messages: + n := copy(b, data) + if n < len(data) { + c.mu.Lock() + c.readBuf = data[n:] + c.mu.Unlock() + } + return n, nil + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } +} + +func (c *websocketConn) Write(b []byte) (int, error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + } + + uint8Array := js.Global().Get("Uint8Array").New(len(b)) + js.CopyBytesToJS(uint8Array, b) + c.ws.Call("send", uint8Array) + return len(b), nil +} + +func (c *websocketConn) Close() error { + c.cancel() + c.ws.Call("close") + return nil +} + +func (c *websocketConn) LocalAddr() net.Addr { + return nil +} + +func (c *websocketConn) RemoteAddr() net.Addr { + return stringAddr(c.remoteAddr) +} +func (c *websocketConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *websocketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// stringAddr is a simple net.Addr that returns a string +type stringAddr string + +func (s stringAddr) Network() string { return "tcp" } +func (s stringAddr) String() string { return string(s) } + +// WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. +func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + scheme := "wss" + if !tlsEnabled { + scheme = "ws" + } + wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) + + ws := js.Global().Get("WebSocket").New(wsURL) + + connCtx, connCancel := context.WithCancel(context.Background()) + conn := &websocketConn{ + ws: ws, + remoteAddr: addr, + messages: make(chan []byte, 100), + ctx: connCtx, + cancel: connCancel, + } + + ws.Set("binaryType", "arraybuffer") + + openCh := make(chan struct{}) + errorCh := make(chan error, 1) + + ws.Set("onopen", js.FuncOf(func(this js.Value, args []js.Value) any { + close(openCh) + return nil + })) + + ws.Set("onerror", js.FuncOf(func(this js.Value, args []js.Value) any { + select { + case errorCh <- wsproxy.ErrConnectionFailed: + default: + } + return nil + })) + + ws.Set("onmessage", js.FuncOf(func(this js.Value, args []js.Value) any { + event := args[0] + data := event.Get("data") + + uint8Array := js.Global().Get("Uint8Array").New(data) + length := uint8Array.Get("length").Int() + bytes := make([]byte, length) + js.CopyBytesToGo(bytes, uint8Array) + + select { + case conn.messages <- bytes: + default: + log.Warnf("gRPC WebSocket message dropped for %s - buffer full", addr) + } + return nil + })) + + ws.Set("onclose", js.FuncOf(func(this js.Value, args []js.Value) any { + conn.cancel() + return nil + })) + + select { + case <-openCh: + return conn, nil + case err := <-errorCh: + return nil, err + case <-ctx.Done(): + ws.Call("close") + return nil, ctx.Err() + case <-time.After(dialTimeout): + ws.Call("close") + return nil, wsproxy.ErrConnectionTimeout + } + }) +} diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go new file mode 100644 index 000000000..8d117c7d9 --- /dev/null +++ b/util/wsproxy/constants.go @@ -0,0 +1,13 @@ +package wsproxy + +import "errors" + +// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. +const ProxyPath = "/ws-proxy" + +// Common errors +var ( + ErrConnectionTimeout = errors.New("WebSocket connection timeout") + ErrConnectionFailed = errors.New("WebSocket connection failed") + ErrBackendUnavailable = errors.New("backend unavailable") +) diff --git a/util/wsproxy/server/metrics.go b/util/wsproxy/server/metrics.go new file mode 100644 index 000000000..dd3b96dad --- /dev/null +++ b/util/wsproxy/server/metrics.go @@ -0,0 +1,118 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// MetricsRecorder defines the interface for recording proxy metrics +type MetricsRecorder interface { + // RecordConnection records a new connection + RecordConnection(ctx context.Context) + // RecordDisconnection records a connection closing + RecordDisconnection(ctx context.Context) + // RecordBytesTransferred records bytes transferred in a direction + RecordBytesTransferred(ctx context.Context, direction string, bytes int64) + // RecordError records an error + RecordError(ctx context.Context, errorType string) +} + +// NoOpMetricsRecorder is a no-op implementation that does nothing +type NoOpMetricsRecorder struct{} + +func (n NoOpMetricsRecorder) RecordConnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordDisconnection(ctx context.Context) { + // no-op +} +func (n NoOpMetricsRecorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + // no-op +} +func (n NoOpMetricsRecorder) RecordError(ctx context.Context, errorType string) { + // no-op +} + +// Recorder implements MetricsRecorder using OpenTelemetry +type Recorder struct { + activeConnections metric.Int64UpDownCounter + bytesTransferred metric.Int64Counter + errors metric.Int64Counter +} + +// NewMetricsRecorder creates a new OpenTelemetry-based metrics recorder +func NewMetricsRecorder(meter metric.Meter) (*Recorder, error) { + activeConnections, err := meter.Int64UpDownCounter( + "wsproxy_active_connections", + metric.WithDescription("Number of active WebSocket proxy connections"), + ) + if err != nil { + return nil, err + } + + bytesTransferred, err := meter.Int64Counter( + "wsproxy_bytes_transferred_total", + metric.WithDescription("Total bytes transferred through the proxy"), + ) + if err != nil { + return nil, err + } + + errors, err := meter.Int64Counter( + "wsproxy_errors_total", + metric.WithDescription("Total number of proxy errors"), + ) + if err != nil { + return nil, err + } + + return &Recorder{ + activeConnections: activeConnections, + bytesTransferred: bytesTransferred, + errors: errors, + }, nil +} + +func (o *Recorder) RecordConnection(ctx context.Context) { + o.activeConnections.Add(ctx, 1) +} + +func (o *Recorder) RecordDisconnection(ctx context.Context) { + o.activeConnections.Add(ctx, -1) +} + +func (o *Recorder) RecordBytesTransferred(ctx context.Context, direction string, bytes int64) { + o.bytesTransferred.Add(ctx, bytes, metric.WithAttributes( + attribute.String("direction", direction), + )) +} + +func (o *Recorder) RecordError(ctx context.Context, errorType string) { + o.errors.Add(ctx, 1, metric.WithAttributes( + attribute.String("error_type", errorType), + )) +} + +// Option defines functional options for the Proxy +type Option func(*Config) + +// WithMetrics sets a custom metrics recorder +func WithMetrics(recorder MetricsRecorder) Option { + return func(c *Config) { + c.MetricsRecorder = recorder + } +} + +// WithOTelMeter creates and sets an OpenTelemetry metrics recorder +func WithOTelMeter(meter metric.Meter) Option { + return func(c *Config) { + if recorder, err := NewMetricsRecorder(meter); err == nil { + c.MetricsRecorder = recorder + } else { + log.Warnf("Failed to create OTel metrics recorder: %v", err) + } + } +} diff --git a/util/wsproxy/server/proxy.go b/util/wsproxy/server/proxy.go new file mode 100644 index 000000000..977440a60 --- /dev/null +++ b/util/wsproxy/server/proxy.go @@ -0,0 +1,227 @@ +package server + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/netip" + "sync" + "time" + + "github.com/coder/websocket" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util/wsproxy" +) + +const ( + dialTimeout = 10 * time.Second + bufferSize = 32 * 1024 +) + +// Config contains the configuration for the WebSocket proxy. +type Config struct { + LocalGRPCAddr netip.AddrPort + Path string + MetricsRecorder MetricsRecorder +} + +// Proxy handles WebSocket to TCP proxying for gRPC connections. +type Proxy struct { + config Config + metrics MetricsRecorder +} + +// New creates a new WebSocket proxy instance with optional configuration +func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy { + config := Config{ + LocalGRPCAddr: localGRPCAddr, + Path: wsproxy.ProxyPath, + MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op + } + + for _, opt := range opts { + opt(&config) + } + + return &Proxy{ + config: config, + metrics: config.MetricsRecorder, + } +} + +// Handler returns an http.Handler that proxies WebSocket connections to the local gRPC server. +func (p *Proxy) Handler() http.Handler { + return http.HandlerFunc(p.handleWebSocket) +} + +func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + p.metrics.RecordConnection(ctx) + defer p.metrics.RecordDisconnection(ctx) + + log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr) + acceptOptions := &websocket.AcceptOptions{ + OriginPatterns: []string{"*"}, + } + + wsConn, err := websocket.Accept(w, r, acceptOptions) + if err != nil { + p.metrics.RecordError(ctx, "websocket_accept_failed") + log.Errorf("WebSocket upgrade failed from %s: %v", r.RemoteAddr, err) + return + } + defer func() { + if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil { + log.Debugf("Failed to close WebSocket: %v", err) + } + }() + + log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr) + tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout) + if err != nil { + p.metrics.RecordError(ctx, "tcp_dial_failed") + log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err) + if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil { + log.Debugf("Failed to close WebSocket after connection failure: %v", err) + } + return + } + defer func() { + if err := tcpConn.Close(); err != nil { + log.Debugf("Failed to close TCP connection: %v", err) + } + }() + + log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr) + + p.proxyData(ctx, wsConn, tcpConn) +} + +func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) { + proxyCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var wg sync.WaitGroup + wg.Add(2) + + go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn) + go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn) + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + log.Tracef("Proxy data transfer completed, both goroutines terminated") + case <-proxyCtx.Done(): + log.Tracef("Proxy data transfer cancelled, forcing connection closure") + + if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil { + log.Tracef("Error closing WebSocket during cancellation: %v", err) + } + if err := tcpConn.Close(); err != nil { + log.Tracef("Error closing TCP connection during cancellation: %v", err) + } + + select { + case <-done: + log.Tracef("Goroutines terminated after forced connection closure") + case <-time.After(2 * time.Second): + log.Tracef("Goroutines did not terminate within timeout after connection closure") + } + } +} + +func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + for { + msgType, data, err := wsConn.Read(ctx) + if err != nil { + switch { + case ctx.Err() != nil: + log.Debugf("wsToTCP goroutine terminating due to context cancellation") + case websocket.CloseStatus(err) == websocket.StatusNormalClosure: + log.Debugf("WebSocket closed normally") + default: + p.metrics.RecordError(ctx, "websocket_read_error") + log.Errorf("WebSocket read error: %v", err) + } + return + } + + if msgType != websocket.MessageBinary { + log.Warnf("Unexpected WebSocket message type: %v", msgType) + continue + } + + if ctx.Err() != nil { + log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write") + return + } + + if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP write deadline: %v", err) + } + + n, err := tcpConn.Write(data) + if err != nil { + p.metrics.RecordError(ctx, "tcp_write_error") + log.Errorf("TCP write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n)) + } +} + +func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) { + defer wg.Done() + defer cancel() + + buf := make([]byte, bufferSize) + for { + if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Debugf("Failed to set TCP read deadline: %v", err) + } + n, err := tcpConn.Read(buf) + + if err != nil { + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation") + return + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + continue + } + + if err != io.EOF { + log.Errorf("TCP read error: %v", err) + } + return + } + + if ctx.Err() != nil { + log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write") + return + } + + if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { + p.metrics.RecordError(ctx, "websocket_write_error") + log.Errorf("WebSocket write error: %v", err) + return + } + + p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n)) + } +} From 4d7e59f199ce72e8b5d84f0953cf752181e3c923 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 2 Oct 2025 00:10:47 +0200 Subject: [PATCH 26/30] [client,signal,management] Adjust browser client ws proxy paths (#4565) --- client/grpc/dialer.go | 7 ++++--- client/grpc/dialer_generic.go | 2 +- client/grpc/dialer_js.go | 5 +++-- flow/client/client.go | 3 ++- management/internals/server/server.go | 2 +- shared/management/client/grpc.go | 3 ++- shared/signal/client/grpc.go | 3 ++- signal/cmd/run.go | 2 +- util/wsproxy/client/dialer_js.go | 5 +++-- util/wsproxy/constants.go | 9 ++++++++- 10 files changed, 27 insertions(+), 14 deletions(-) diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 7cb38fbff..54fbb002c 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -25,8 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } -// CreateConnection creates a gRPC client connection with the appropriate transport options -func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) { +// CreateConnection creates a gRPC client connection with the appropriate transport options. +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) if tlsEnabled { certPool, err := x509.SystemCertPool() @@ -49,7 +50,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc. connCtx, addr, transportOption, - WithCustomDialer(tlsEnabled), + WithCustomDialer(tlsEnabled, component), grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index a0d6cee0b..96f347c64 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool) grpc.DialOption { +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() diff --git a/client/grpc/dialer_js.go b/client/grpc/dialer_js.go index e132c0098..b89ec3c21 100644 --- a/client/grpc/dialer_js.go +++ b/client/grpc/dialer_js.go @@ -7,6 +7,7 @@ import ( ) // WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments. -func WithCustomDialer(tlsEnabled bool) grpc.DialOption { - return client.WithWebSocketDialer(tlsEnabled) +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { + return client.WithWebSocketDialer(tlsEnabled, component) } diff --git a/flow/client/client.go b/flow/client/client.go index 03a4accaf..318fcfe1e 100644 --- a/flow/client/client.go +++ b/flow/client/client.go @@ -23,6 +23,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/client/grpc" "github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/util/embeddedroots" + "github.com/netbirdio/netbird/util/wsproxy" ) type GRPCClient struct { @@ -54,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl } opts = append(opts, - nbgrpc.WithCustomDialer(tlsEnabled), + nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent), grpc.WithIdleTimeout(interval*2), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ae9ac4a60..94c633fc6 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -259,7 +259,7 @@ func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Hand case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") || strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")): gRPCHandler.ServeHTTP(writer, request) - case request.URL.Path == wsproxy.ProxyPath: + case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent: wsProxy.Handler().ServeHTTP(writer, request) default: httpHandler.ServeHTTP(writer, request) diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index f30e965be..076f2532b 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/util/wsproxy" ) const ConnectTimeout = 10 * time.Second @@ -52,7 +53,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5ca0c0282..31f3372c0 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/shared/management/client" "github.com/netbirdio/netbird/shared/signal/proto" + "github.com/netbirdio/netbird/util/wsproxy" ) // ConnStateNotifier is a wrapper interface of the status recorder @@ -57,7 +58,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo operation := func() error { var err error - conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled) + conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { log.Printf("createConnection error: %v", err) return err diff --git a/signal/cmd/run.go b/signal/cmd/run.go index e2a69a75b..696c44723 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -258,7 +258,7 @@ func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case r.URL.Path == wsproxy.ProxyPath: + case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent: wsProxy.Handler().ServeHTTP(w, r) default: grpcServer.ServeHTTP(w, r) diff --git a/util/wsproxy/client/dialer_js.go b/util/wsproxy/client/dialer_js.go index 2caeed025..bd50f51b5 100644 --- a/util/wsproxy/client/dialer_js.go +++ b/util/wsproxy/client/dialer_js.go @@ -96,13 +96,14 @@ func (s stringAddr) Network() string { return "tcp" } func (s stringAddr) String() string { return string(s) } // WithWebSocketDialer returns a gRPC dial option that uses WebSocket transport for JS/WASM environments. -func WithWebSocketDialer(tlsEnabled bool) grpc.DialOption { +// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). +func WithWebSocketDialer(tlsEnabled bool, component string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { scheme := "wss" if !tlsEnabled { scheme = "ws" } - wsURL := fmt.Sprintf("%s://%s%s", scheme, addr, wsproxy.ProxyPath) + wsURL := fmt.Sprintf("%s://%s%s%s", scheme, addr, wsproxy.ProxyPath, component) ws := js.Global().Get("WebSocket").New(wsURL) diff --git a/util/wsproxy/constants.go b/util/wsproxy/constants.go index 8d117c7d9..a31c0fbc8 100644 --- a/util/wsproxy/constants.go +++ b/util/wsproxy/constants.go @@ -2,9 +2,16 @@ package wsproxy import "errors" -// ProxyPath is the standard path where the WebSocket proxy is mounted on servers. +// ProxyPath is the base path where the WebSocket proxy is mounted on servers. const ProxyPath = "/ws-proxy" +// Component paths that are appended to ProxyPath +const ( + ManagementComponent = "/management" + SignalComponent = "/signal" + FlowComponent = "/flow" +) + // Common errors var ( ErrConnectionTimeout = errors.New("WebSocket connection timeout") From b85045e723b251e30d3f22e14da2bce817b7767e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 1 Oct 2025 19:52:54 -0300 Subject: [PATCH 27/30] [misc] Update infra scripts with ws proxy for browser client (#4566) * Update infra scripts with ws proxy for browser client * add ws proxy to nginx tmpl --- infrastructure_files/docker-compose.yml.tmpl.traefik | 7 ++++++- infrastructure_files/getting-started-with-zitadel.sh | 2 ++ infrastructure_files/nginx.tmpl.conf | 8 ++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 08749a4f7..fb01e6867 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -45,6 +45,9 @@ services: - $SIGNAL_VOLUMENAME:/var/lib/netbird labels: - traefik.enable=true + - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) + - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal + - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000 - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c @@ -87,7 +90,9 @@ services: - traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`) - traefik.http.routers.netbird-api.service=netbird-api - traefik.http.services.netbird-api.loadbalancer.server.port=33073 - + - traefik.http.routers.netbird-wsproxy-mgmt.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/management`) + - traefik.http.routers.netbird-wsproxy-mgmt.service=netbird-wsproxy-mgmt + - traefik.http.services.netbird-wsproxy-mgmt.loadbalancer.server.port=33073 - traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`) - traefik.http.routers.netbird-management.service=netbird-management - traefik.http.services.netbird-management.loadbalancer.server.port=33073 diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index cfec1000e..be9662345 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -621,9 +621,11 @@ renderCaddyfile() { # relay reverse_proxy /relay* relay:80 # Signal + reverse_proxy /ws-proxy/signal* signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management reverse_proxy /api/* management:80 + reverse_proxy /ws-proxy/management* management:80 reverse_proxy /management.ManagementService/* h2c://management:80 # Zitadel reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080 diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index f7fa4a9d0..fbd892c29 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -52,6 +52,10 @@ server { location / { proxy_pass http://dashboard; } + # Proxy Signal wsproxy endpoint + location /ws-proxy/signal { + proxy_pass http://signal; + } # Proxy Signal location /signalexchange.SignalExchange/ { grpc_pass grpc://signal; @@ -64,6 +68,10 @@ server { location /api { proxy_pass http://management; } + # Proxy Management wsproxy endpoint + location /ws-proxy/management { + proxy_pass http://management; + } # Proxy Management grpc endpoint location /management.ManagementService/ { grpc_pass grpc://management; From 9bcd3ebed4c6e9144b939a3ed6555e9d6a0f0be4 Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 2 Oct 2025 06:02:10 +0700 Subject: [PATCH 28/30] [management,client] Make DNS ForwarderPort Configurable & Change Well Known Port (#4479) makes the DNS forwarder port configurable in the management and client components, while changing the well-known port from 5454 to 22054. The change includes version-aware port assignment to ensure backward compatibility. - Adds a configurable `ForwarderPort` field to the DNS configuration protocol - Implements version-based port computation that returns the new port (22054) only when all peers support version 0.59.0 or newer - Updates the client to dynamically restart the DNS forwarder when the port changes --- client/internal/dnsfwd/manager.go | 33 +- client/internal/engine.go | 35 +- client/internal/netflow/logger/logger.go | 2 +- .../routemanager/dnsinterceptor/handler.go | 4 +- go.mod | 2 +- management/server/dns.go | 44 ++- management/server/dns_test.go | 132 +++++++- management/server/grpcserver.go | 23 +- management/server/peer.go | 21 +- management/server/peer/peer.go | 11 +- management/server/peer_test.go | 3 +- shared/management/proto/management.pb.go | 304 +++++++++--------- shared/management/proto/management.proto | 1 + 13 files changed, 416 insertions(+), 199 deletions(-) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index bf2ee839b..5c7a3fbdd 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "sync" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" @@ -11,14 +12,18 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +var ( + // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also + listenPort uint16 = 5353 + listenPortMu sync.RWMutex ) const ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - ListenPort = 5353 - dnsTTL = 60 //seconds + dnsTTL = 60 //seconds ) // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. @@ -35,12 +40,20 @@ type Manager struct { fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder + port uint16 } -func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager { +func ListenPort() uint16 { + listenPortMu.RLock() + defer listenPortMu.RUnlock() + return listenPort +} + +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + port: port, } } @@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder) + if m.port > 0 { + listenPortMu.Lock() + listenPort = m.port + listenPortMu.Unlock() + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort}, + Values: []uint16{ListenPort()}, } if m.firewall == nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 3fa0b58a8..646e059d4 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,6 +202,9 @@ type Engine struct { // WireGuard interface monitor wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitorWg sync.WaitGroup + + // dns forwarder port + dnsFwdPort uint16 } // Peer is an instance of the Connection Peer @@ -244,6 +247,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + dnsFwdPort: dnsfwd.ListenPort(), } sm := profilemanager.NewServiceManager("") @@ -1080,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1839,6 +1843,7 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, + forwarderPort uint16, ) { if e.config.DisableServerRoutes { return @@ -1855,16 +1860,20 @@ func (e *Engine) updateDNSForwarder( } if len(fwdEntries) > 0 { - if e.dnsForwardMgr == nil { - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - + switch { + case e.dnsForwardMgr == nil: + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - } else { + case e.dnsFwdPort != forwarderPort: + log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) + e.restartDnsFwd(fwdEntries, forwarderPort) + e.dnsFwdPort = forwarderPort + + default: e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { @@ -1874,6 +1883,20 @@ func (e *Engine) updateDNSForwarder( } e.dnsForwardMgr = nil } + +} + +func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { + log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) + // stop and start the forwarder to apply the new port + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index e28fdf2f4..899faf108 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { return false } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 9069cdcc5..47c2ffcda 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -24,8 +24,8 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const dnsTimeout = 8 * time.Second @@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/go.mod b/go.mod index c4b629993..a1560b409 100644 --- a/go.mod +++ b/go.mod @@ -102,6 +102,7 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a + golang.org/x/mod v0.25.0 golang.org/x/net v0.42.0 golang.org/x/oauth2 v0.28.0 golang.org/x/sync v0.16.0 @@ -243,7 +244,6 @@ require ( go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect - golang.org/x/mod v0.25.0 // indirect golang.org/x/text v0.27.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.34.0 // indirect diff --git a/management/server/dns.go b/management/server/dns.go index 6b73dbd0e..534f43ec6 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -6,9 +6,11 @@ import ( "sync" log "github.com/sirupsen/logrus" + "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" @@ -18,6 +20,13 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +const ( + dnsForwarderPort = 22054 + oldForwarderPort = 5353 +) + +const dnsForwarderPortMinVersion = "v0.59.0" + // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { NameServerGroups sync.Map @@ -183,12 +192,45 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return oldForwarderPort + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return oldForwarderPort + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return oldForwarderPort + } + } + + // All peers have the required version or newer + return dnsForwarderPort +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { +func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ ServiceEnable: update.ServiceEnable, CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, } for _, zone := range update.CustomZones { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index a2a2ce529..83caf74ef 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/shared/management/status" @@ -324,13 +323,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return nil, err } - account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ + account.NameServerGroups[dnsNSGroup1] = &nbdns.NameServerGroup{ ID: dnsNSGroup1, Name: "ns-group-1", - NameServers: []dns.NameServer{{ + NameServers: []nbdns.NameServer{{ IP: netip.MustParseAddr(savedPeer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, Primary: true, Enabled: true, @@ -395,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache) + toProtocolDNSConfig(testData, cache, dnsForwarderPort) } }) @@ -403,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache) + toProtocolDNSConfig(testData, cache, dnsForwarderPort) } }) } @@ -456,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache) + result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache) + result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache) + result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -483,6 +482,107 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } } +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != dnsForwarderPort { + t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == oldForwarderPort { + t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != oldForwarderPort { + t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) + } +} + func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) @@ -534,10 +634,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }() _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{ + context.Background(), account.Id, "ns-group", "ns-group", []nbdns.NameServer{{ IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, []string{"groupB"}, true, []string{}, true, userID, false, @@ -567,10 +667,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }() _, err = manager.CreateNameServerGroup( - context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + context.Background(), account.Id, "ns-group-1", "ns-group-1", []nbdns.NameServer{{ IP: netip.MustParseAddr(peer1.IP.String()), - NSType: dns.UDPNameServerType, - Port: dns.DefaultDNSPort, + NSType: nbdns.UDPNameServerType, + Port: nbdns.DefaultDNSPort, }}, []string{"groupA"}, true, []string{}, true, userID, false, diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 1177eefff..12b59b691 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -715,13 +715,13 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set } } -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { response := &proto.SyncResponse{ PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), }, Checks: toProtocolChecks(ctx, checks), } @@ -732,11 +732,11 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P response.NetworkMap.PeerConfig = response.PeerConfig - allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName) - response.RemotePeers = allPeers - response.NetworkMap.RemotePeers = allPeers - response.RemotePeersIsEmpty = len(allPeers) == 0 + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) @@ -808,7 +808,14 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups) + // Get all peers in the account for forwarder port computation + allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") + if err != nil { + return fmt.Errorf("get account peers: %w", err) + } + dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/peer.go b/management/server/peer.go index ea4617af0..4cf5d1e46 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -729,7 +729,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool - var updated bool + var updated, versionChanged bool var err error var postureChecks []*posture.Checks @@ -769,7 +769,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return err } - updated = peer.UpdateMetaIfNew(sync.Meta) + updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta) if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) @@ -788,7 +788,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil, nil, nil, err } - if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { + if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -880,7 +880,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return err } - isPeerUpdated = peer.UpdateMetaIfNew(login.Meta) + isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true @@ -1229,6 +1229,8 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -1265,7 +1267,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account peerGroups := account.GetPeerGroups(p.ID) start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups)) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) @@ -1376,7 +1378,9 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI } peerGroups := account.GetPeerGroups(peerId) - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups)) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) + + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } @@ -1549,6 +1553,8 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } + dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { return nil, fmt.Errorf("failed to remove peer %s from groups", peer.ID) @@ -1592,6 +1598,9 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto RemotePeersIsEmpty: true, FirewallRules: []*proto.FirewallRule{}, FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, }, }, NetworkMap: &types.NetworkMap{}, diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index f89f10dac..a898fd782 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -233,21 +233,24 @@ func (p *Peer) Copy() *Peer { // UpdateMetaIfNew updates peer's system metadata if new information is provided // returns true if meta was updated, false otherwise -func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { +func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) { if meta.isEmpty() { - return false + return updated, versionChanged } + versionChanged = p.Meta.WtVersion != meta.WtVersion + // Avoid overwriting UIVersion if the update was triggered sole by the CLI client if meta.UIVersion == "" { meta.UIVersion = p.Meta.UIVersion } if p.Meta.isEqual(meta) { - return false + return updated, versionChanged } p.Meta = meta - return true + updated = true + return updated, versionChanged } // GetLastLogin returns the last login time of the peer. diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 734536d7b..42b3244ae 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) assert.NotNil(t, response) // assert peer config @@ -1212,6 +1212,7 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID) // assert network map DNSConfig assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable) + assert.Equal(t, int64(dnsForwarderPort), response.NetworkMap.DNSConfig.ForwarderPort) assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones)) assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups)) // assert network map DNSConfig.CustomZones diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 8381d6682..0de00ec0c 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -7,12 +7,13 @@ package proto import ( + reflect "reflect" + sync "sync" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" durationpb "google.golang.org/protobuf/types/known/durationpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" - reflect "reflect" - sync "sync" ) const ( @@ -2491,6 +2492,7 @@ type DNSConfig struct { ServiceEnable bool `protobuf:"varint,1,opt,name=ServiceEnable,proto3" json:"ServiceEnable,omitempty"` NameServerGroups []*NameServerGroup `protobuf:"bytes,2,rep,name=NameServerGroups,proto3" json:"NameServerGroups,omitempty"` CustomZones []*CustomZone `protobuf:"bytes,3,rep,name=CustomZones,proto3" json:"CustomZones,omitempty"` + ForwarderPort int64 `protobuf:"varint,4,opt,name=ForwarderPort,proto3" json:"ForwarderPort,omitempty"` } func (x *DNSConfig) Reset() { @@ -2546,6 +2548,13 @@ func (x *DNSConfig) GetCustomZones() []*CustomZone { return nil } +func (x *DNSConfig) GetForwarderPort() int64 { + if x != nil { + return x.ForwarderPort + } + return 0 +} + // CustomZone represents a dns.CustomZone type CustomZone struct { state protoimpl.MessageState @@ -3721,7 +3730,7 @@ var file_management_proto_rawDesc = []byte{ 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x73, 0x6b, 0x69, 0x70, 0x41, 0x75, - 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, + 0x74, 0x6f, 0x41, 0x70, 0x70, 0x6c, 0x79, 0x22, 0xda, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, @@ -3732,157 +3741,160 @@ var file_management_proto_rawDesc = []byte{ 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, - 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, - 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, - 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, - 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, - 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, - 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, - 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, - 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, - 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, - 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, - 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, - 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, - 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, - 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, - 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, - 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, - 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, - 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, - 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, - 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, - 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, - 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, - 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, - 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, - 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, - 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, - 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, - 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, - 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, - 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, - 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, - 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, - 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, - 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, - 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, - 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, - 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, - 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, - 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, - 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, - 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, - 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, - 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, - 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, - 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, - 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, - 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, - 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, - 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, - 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, - 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, - 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, - 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, - 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, - 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, - 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, - 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, - 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, - 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x12, 0x24, + 0x0a, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0d, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x65, 0x72, + 0x50, 0x6f, 0x72, 0x74, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, + 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, + 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, + 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, + 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, + 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, + 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, + 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, + 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, + 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, + 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x22, 0xa7, 0x02, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, + 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, + 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x30, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, 0x44, 0x22, 0x38, + 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, + 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, + 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, + 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, + 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, + 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, + 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x22, 0x87, 0x03, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, + 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, + 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, + 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, + 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, + 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, + 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, + 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, + 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, + 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, + 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x07, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x26, 0x0a, + 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, + 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, + 0x44, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x49, + 0x44, 0x12, 0x18, 0x0a, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x22, 0xf2, 0x01, 0x0a, 0x0e, + 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x34, + 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, + 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x3e, 0x0a, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, + 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x52, 0x0f, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x50, 0x6f, 0x72, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, + 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x11, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x41, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, + 0x50, 0x6f, 0x72, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, + 0x52, 0x0e, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x6c, 0x61, 0x74, 0x65, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, + 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, + 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, + 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, + 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, + 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, + 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, + 0x4f, 0x50, 0x10, 0x01, 0x32, 0xcd, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, - 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, - 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, + 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, + 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, + 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, + 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, - 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, + 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, - 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, - 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, - 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x3b, 0x0a, 0x06, 0x4c, 0x6f, 0x67, 0x6f, 0x75, + 0x74, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, + 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index dcdd387b4..ad82d37d9 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -410,6 +410,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; + int64 ForwarderPort = 4; } // CustomZone represents a dns.CustomZone From 95794f53ce9454f18f73f50c67223ece2b50e68d Mon Sep 17 00:00:00 2001 From: hakansa <43675540+hakansa@users.noreply.github.com> Date: Thu, 2 Oct 2025 17:42:25 +0700 Subject: [PATCH 29/30] [client] fix Windows NRPT Policy Path (#4572) [client] fix Windows NRPT Policy Path --- client/internal/dns/host_windows.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 0d3f033fb..a14a01f40 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -240,15 +240,17 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 for i, domain := range domains { + localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) + gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) singleDomain := []string{domain} - if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) } if r.gpo { - if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil { + if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { return i, fmt.Errorf("configure gpo DNS policy: %w", err) } } From e7b5537dcc280384470668f461bbb1f7d2f41218 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 2 Oct 2025 13:51:39 +0200 Subject: [PATCH 30/30] Add websocket paths including relay to nginx template (#4573) --- infrastructure_files/nginx.tmpl.conf | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/infrastructure_files/nginx.tmpl.conf b/infrastructure_files/nginx.tmpl.conf index fbd892c29..46cb195e7 100644 --- a/infrastructure_files/nginx.tmpl.conf +++ b/infrastructure_files/nginx.tmpl.conf @@ -20,6 +20,10 @@ upstream management { # insert the grpc+http port of your management container here server 127.0.0.1:8012; } +upstream relay { + # insert the port of your relay container here + server 127.0.0.1:33080; +} server { # HTTP server config @@ -55,6 +59,10 @@ server { # Proxy Signal wsproxy endpoint location /ws-proxy/signal { proxy_pass http://signal; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; } # Proxy Signal location /signalexchange.SignalExchange/ { @@ -71,6 +79,10 @@ server { # Proxy Management wsproxy endpoint location /ws-proxy/management { proxy_pass http://management; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; } # Proxy Management grpc endpoint location /management.ManagementService/ { @@ -80,6 +92,14 @@ server { grpc_send_timeout 1d; grpc_socket_keepalive on; } + # Proxy Relay + location /relay { + proxy_pass http://relay; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "Upgrade"; + proxy_set_header Host $host; + } ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem;