From fbb1b55beb95d3106e7a279931a69bac34f0af1f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 4 Jul 2025 19:52:27 +0200 Subject: [PATCH 01/12] [client] refactor lazy detection (#4050) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces a new inactivity package responsible for monitoring peer activity and notifying when peers become inactive. Introduces a new Signal message type to close the peer connection after the idle timeout is reached. Periodically checks the last activity of registered peers via a Bind interface. Notifies via a channel when peers exceed a configurable inactivity threshold. Default settings DefaultInactivityThreshold is set to 15 minutes, with a minimum allowed threshold of 1 minute. Limitations This inactivity check does not support kernel WireGuard integration. In kernel–user space communication, the user space side will always be responsible for closing the connection. --- client/iface/bind/activity.go | 94 +++++ client/iface/bind/activity_test.go | 27 ++ client/iface/bind/ice_bind.go | 57 ++- client/iface/configurer/kernel_unix.go | 4 + client/iface/configurer/usp.go | 36 +- client/iface/device/device_android.go | 2 +- client/iface/device/device_darwin.go | 2 +- client/iface/device/device_ios.go | 2 +- client/iface/device/device_netstack.go | 2 +- client/iface/device/device_usp_unix.go | 2 +- client/iface/device/device_windows.go | 2 +- client/iface/device/interface.go | 1 + client/iface/iface.go | 8 + client/internal/conn_mgr.go | 66 ++-- client/internal/engine.go | 32 +- client/internal/engine_test.go | 15 +- client/internal/iface_common.go | 1 + client/internal/lazyconn/activity/listener.go | 4 +- client/internal/lazyconn/activity/manager.go | 13 +- .../lazyconn/inactivity/inactivity.go | 75 ---- .../lazyconn/inactivity/inactivity_test.go | 156 --------- .../internal/lazyconn/inactivity/manager.go | 152 ++++++++ .../lazyconn/inactivity/manager_test.go | 113 ++++++ client/internal/lazyconn/inactivity/ticker.go | 24 ++ client/internal/lazyconn/manager/manager.go | 326 ++++++++---------- client/internal/lazyconn/wgiface.go | 2 + client/internal/peer/conn.go | 49 ++- client/internal/peer/signaler.go | 10 + client/internal/peerstore/store.go | 13 +- monotime/time.go | 29 ++ monotime/time_test.go | 20 ++ signal/proto/signalexchange.pb.go | 58 ++-- signal/proto/signalexchange.proto | 3 +- 33 files changed, 857 insertions(+), 543 deletions(-) create mode 100644 client/iface/bind/activity.go create mode 100644 client/iface/bind/activity_test.go delete mode 100644 client/internal/lazyconn/inactivity/inactivity.go delete mode 100644 client/internal/lazyconn/inactivity/inactivity_test.go create mode 100644 client/internal/lazyconn/inactivity/manager.go create mode 100644 client/internal/lazyconn/inactivity/manager_test.go create mode 100644 client/internal/lazyconn/inactivity/ticker.go create mode 100644 monotime/time.go create mode 100644 monotime/time_test.go diff --git a/client/iface/bind/activity.go b/client/iface/bind/activity.go new file mode 100644 index 000000000..d3b406bcd --- /dev/null +++ b/client/iface/bind/activity.go @@ -0,0 +1,94 @@ +package bind + +import ( + "net/netip" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/monotime" +) + +const ( + saveFrequency = int64(5 * time.Second) +) + +type PeerRecord struct { + Address netip.AddrPort + LastActivity atomic.Int64 // UnixNano timestamp +} + +type ActivityRecorder struct { + mu sync.RWMutex + peers map[string]*PeerRecord // publicKey to PeerRecord map + addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map +} + +func NewActivityRecorder() *ActivityRecorder { + return &ActivityRecorder{ + peers: make(map[string]*PeerRecord), + addrToPeer: make(map[netip.AddrPort]*PeerRecord), + } +} + +// GetLastActivities returns a snapshot of peer last activity +func (r *ActivityRecorder) GetLastActivities() map[string]time.Time { + r.mu.RLock() + defer r.mu.RUnlock() + + activities := make(map[string]time.Time, len(r.peers)) + for key, record := range r.peers { + unixNano := record.LastActivity.Load() + activities[key] = time.Unix(0, unixNano) + } + return activities +} + +// UpsertAddress adds or updates the address for a publicKey +func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) { + r.mu.Lock() + defer r.mu.Unlock() + + if pr, exists := r.peers[publicKey]; exists { + delete(r.addrToPeer, pr.Address) + pr.Address = address + } else { + record := &PeerRecord{ + Address: address, + } + record.LastActivity.Store(monotime.Now()) + r.peers[publicKey] = record + } + + r.addrToPeer[address] = r.peers[publicKey] +} + +func (r *ActivityRecorder) Remove(publicKey string) { + r.mu.Lock() + defer r.mu.Unlock() + if record, exists := r.peers[publicKey]; exists { + delete(r.addrToPeer, record.Address) + delete(r.peers, publicKey) + } +} + +// record updates LastActivity for the given address using atomic store +func (r *ActivityRecorder) record(address netip.AddrPort) { + r.mu.RLock() + record, ok := r.addrToPeer[address] + r.mu.RUnlock() + if !ok { + log.Warnf("could not find record for address %s", address) + return + } + + now := monotime.Now() + last := record.LastActivity.Load() + if now-last < saveFrequency { + return + } + + _ = record.LastActivity.CompareAndSwap(last, now) +} diff --git a/client/iface/bind/activity_test.go b/client/iface/bind/activity_test.go new file mode 100644 index 000000000..598607b95 --- /dev/null +++ b/client/iface/bind/activity_test.go @@ -0,0 +1,27 @@ +package bind + +import ( + "net/netip" + "testing" + "time" +) + +func TestActivityRecorder_GetLastActivities(t *testing.T) { + peer := "peer1" + ar := NewActivityRecorder() + ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820")) + activities := ar.GetLastActivities() + + p, ok := activities[peer] + if !ok { + t.Fatalf("Expected activity for peer %s, but got none", peer) + } + + if p.IsZero() { + t.Fatalf("Expected activity for peer %s, but got zero", peer) + } + + if p.Before(time.Now().Add(-2 * time.Minute)) { + t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p) + } +} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 66ec6a00d..bb7a27279 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,7 @@ package bind import ( + "encoding/binary" "fmt" "net" "net/netip" @@ -51,22 +52,24 @@ type ICEBind struct { closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. closed bool - muUDPMux sync.Mutex - udpMux *UniversalUDPMuxDefault - address wgaddr.Address + muUDPMux sync.Mutex + udpMux *UniversalUDPMuxDefault + address wgaddr.Address + activityRecorder *ActivityRecorder } func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ - StdNetBind: b, - RecvChan: make(chan RecvMessage, 1), - transportNet: transportNet, - filterFn: filterFn, - endpoints: make(map[netip.Addr]net.Conn), - closedChan: make(chan struct{}), - closed: true, - address: address, + StdNetBind: b, + RecvChan: make(chan RecvMessage, 1), + transportNet: transportNet, + filterFn: filterFn, + endpoints: make(map[netip.Addr]net.Conn), + closedChan: make(chan struct{}), + closed: true, + address: address, + activityRecorder: NewActivityRecorder(), } rc := receiverCreator{ @@ -100,6 +103,10 @@ func (s *ICEBind) Close() error { return s.StdNetBind.Close() } +func (s *ICEBind) ActivityRecorder() *ActivityRecorder { + return s.activityRecorder +} + // GetICEMux returns the ICE UDPMux that was created and used by ICEBind func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { s.muUDPMux.Lock() @@ -199,6 +206,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r continue } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + + if isTransportPkg(msg.Buffers, msg.N) { + s.activityRecorder.record(addrPort) + } + ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep @@ -257,6 +269,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo copy(buffs[0], msg.Buffer) sizes[0] = len(msg.Buffer) eps[0] = wgConn.Endpoint(msg.Endpoint) + + if isTransportPkg(buffs, sizes[0]) { + if ep, ok := eps[0].(*Endpoint); ok { + c.activityRecorder.record(ep.AddrPort) + } + } + return 1, nil } } @@ -272,3 +291,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) { } msgsPool.Put(msgs) } + +func isTransportPkg(buffers [][]byte, n int) bool { + // The first buffer should contain at least 4 bytes for type + if len(buffers[0]) < 4 { + return true + } + + // WireGuard packet type is a little-endian uint32 at start + packetType := binary.LittleEndian.Uint32(buffers[0][:4]) + + // Check if packetType matches known WireGuard message types + if packetType == 4 && n > 32 { + return true + } + return false +} diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 4922a54fc..e2ea19144 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -276,3 +276,7 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) { } return stats, nil } + +func (c *KernelConfigurer) LastActivities() map[string]time.Time { + return nil +} diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 79ce91eea..6ead716f1 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -16,6 +16,7 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/bind" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -36,16 +37,18 @@ const ( var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") type WGUSPConfigurer struct { - device *device.Device - deviceName string + device *device.Device + deviceName string + activityRecorder *bind.ActivityRecorder uapiListener net.Listener } -func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer { +func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer { wgCfg := &WGUSPConfigurer{ - device: device, - deviceName: deviceName, + device: device, + deviceName: deviceName, + activityRecorder: activityRecorder, } wgCfg.startUAPI() return wgCfg @@ -87,7 +90,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, Peers: []wgtypes.PeerConfig{peer}, } - return c.device.IpcSet(toWgUserspaceString(config)) + if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil { + return ipcErr + } + + if endpoint != nil { + addr, err := netip.ParseAddr(endpoint.IP.String()) + if err != nil { + return fmt.Errorf("failed to parse endpoint address: %w", err) + } + addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port)) + c.activityRecorder.UpsertAddress(peerKey, addrPort) + } + return nil } func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { @@ -104,7 +119,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { config := wgtypes.Config{ Peers: []wgtypes.PeerConfig{peer}, } - return c.device.IpcSet(toWgUserspaceString(config)) + ipcErr := c.device.IpcSet(toWgUserspaceString(config)) + + c.activityRecorder.Remove(peerKey) + return ipcErr } func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { @@ -205,6 +223,10 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) { return parseStatus(c.deviceName, ipcStr) } +func (c *WGUSPConfigurer) LastActivities() map[string]time.Time { + return c.activityRecorder.GetLastActivities() +} + // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool func (t *WGUSPConfigurer) startUAPI() { var err error diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index ae9e29bd1..4fe6e466b 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -79,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 01bfbf381..81de0e360 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index 56d44d68e..4613762c3 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index d2f2c87a1..fc3cb0215 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) { device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index c45ae9676..e781f6004 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 41e615bc2..0316c4b8d 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) { return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder()) err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 296eb7dda..d68e6bf90 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -19,4 +19,5 @@ type WGConfigurer interface { Close() GetStats() (map[string]configurer.WGStats, error) FullStats() (*configurer.Stats, error) + LastActivities() map[string]time.Time } diff --git a/client/iface/iface.go b/client/iface/iface.go index 006dfe4e7..1b9055e6c 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -217,6 +217,14 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { return w.configurer.GetStats() } +func (w *WGIface) LastActivities() map[string]time.Time { + w.mu.Lock() + defer w.mu.Unlock() + + return w.configurer.LastActivities() + +} + func (w *WGIface) FullStats() (*configurer.Stats, error) { return w.configurer.FullStats() } diff --git a/client/internal/conn_mgr.go b/client/internal/conn_mgr.go index c630d3052..c76b0a99f 100644 --- a/client/internal/conn_mgr.go +++ b/client/internal/conn_mgr.go @@ -12,7 +12,6 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn/manager" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/route" ) @@ -26,11 +25,11 @@ import ( // // The implementation is not thread-safe; it is protected by engine.syncMsgMux. type ConnMgr struct { - peerStore *peerstore.Store - statusRecorder *peer.Status - iface lazyconn.WGIface - dispatcher *dispatcher.ConnectionDispatcher - enabledLocally bool + peerStore *peerstore.Store + statusRecorder *peer.Status + iface lazyconn.WGIface + enabledLocally bool + rosenpassEnabled bool lazyConnMgr *manager.Manager @@ -39,12 +38,12 @@ type ConnMgr struct { lazyCtxCancel context.CancelFunc } -func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr { +func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr { e := &ConnMgr{ - peerStore: peerStore, - statusRecorder: statusRecorder, - iface: iface, - dispatcher: dispatcher, + peerStore: peerStore, + statusRecorder: statusRecorder, + iface: iface, + rosenpassEnabled: engineConfig.RosenpassEnabled, } if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() { e.enabledLocally = true @@ -64,6 +63,11 @@ func (e *ConnMgr) Start(ctx context.Context) { return } + if e.rosenpassEnabled { + log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started") + return + } + e.initLazyManager(ctx) e.statusRecorder.UpdateLazyConnection(true) } @@ -83,7 +87,12 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er return nil } - log.Infof("lazy connection manager is enabled by management feature flag") + if e.rosenpassEnabled { + log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started") + return nil + } + + log.Warnf("lazy connection manager is enabled by management feature flag") e.initLazyManager(ctx) e.statusRecorder.UpdateLazyConnection(true) return e.addPeersToLazyConnManager() @@ -133,7 +142,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) { excludedPeers = append(excludedPeers, lazyPeerCfg) } - added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers) + added := e.lazyConnMgr.ExcludePeer(excludedPeers) for _, peerID := range added { var peerConn *peer.Conn var exists bool @@ -175,7 +184,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co PeerConnID: conn.ConnID(), Log: conn.Log, } - excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg) + excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg) if err != nil { conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) if err := conn.Open(ctx); err != nil { @@ -201,7 +210,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) { if !ok { return } - defer conn.Close() + defer conn.Close(false) if !e.isStartedWithLazyMgr() { return @@ -211,23 +220,28 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) { conn.Log.Infof("removed peer from lazy conn manager") } -func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) { - conn, ok := e.peerStore.PeerConn(peerKey) - if !ok { - return nil, false - } - +func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) { if !e.isStartedWithLazyMgr() { - return conn, true + return } - if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found { + if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found { conn.Log.Infof("activated peer from inactive state") if err := conn.Open(ctx); err != nil { conn.Log.Errorf("failed to open connection: %v", err) } } - return conn, true +} + +// DeactivatePeer deactivates a peer connection in the lazy connection manager. +// If locally the lazy connection is disabled, we force the peer connection open. +func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) { + if !e.isStartedWithLazyMgr() { + return + } + + conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY") + e.lazyConnMgr.DeactivatePeer(conn.ConnID()) } func (e *ConnMgr) Close() { @@ -244,7 +258,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) { cfg := manager.Config{ InactivityThreshold: inactivityThresholdEnv(), } - e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher) + e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface) e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx) @@ -275,7 +289,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error { lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg) } - return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs) + return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs) } func (e *ConnMgr) closeManager(ctx context.Context) { diff --git a/client/internal/engine.go b/client/internal/engine.go index 74d84569a..e9772b359 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,7 +38,6 @@ import ( nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/peerstore" @@ -175,8 +174,7 @@ type Engine struct { sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServer nbssh.Server - statusRecorder *peer.Status - peerConnDispatcher *dispatcher.ConnectionDispatcher + statusRecorder *peer.Status firewall firewallManager.Manager routeManager routemanager.Manager @@ -458,9 +456,7 @@ func (e *Engine) Start() error { NATExternalIPs: e.parseNATExternalIPMappings(), } - e.peerConnDispatcher = dispatcher.NewConnectionDispatcher() - - e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher) + e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface) e.connMgr.Start(e.ctx) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) @@ -1261,7 +1257,7 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { } if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists { - conn.Close() + conn.Close(false) return fmt.Errorf("peer already exists: %s", peerKey) } @@ -1308,13 +1304,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV } serviceDependencies := peer.ServiceDependencies{ - StatusRecorder: e.statusRecorder, - Signaler: e.signaler, - IFaceDiscover: e.mobileDep.IFaceDiscover, - RelayManager: e.relayManager, - SrWatcher: e.srWatcher, - Semaphore: e.connSemaphore, - PeerConnDispatcher: e.peerConnDispatcher, + StatusRecorder: e.statusRecorder, + Signaler: e.signaler, + IFaceDiscover: e.mobileDep.IFaceDiscover, + RelayManager: e.relayManager, + SrWatcher: e.srWatcher, + Semaphore: e.connSemaphore, } peerConn, err := peer.NewConn(config, serviceDependencies) if err != nil { @@ -1337,11 +1332,16 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key) + conn, ok := e.peerStore.PeerConn(msg.Key) if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) } + msgType := msg.GetBody().GetType() + if msgType != sProto.Body_GO_IDLE { + e.connMgr.ActivatePeer(e.ctx, conn) + } + switch msg.GetBody().Type { case sProto.Body_OFFER: remoteCred, err := signal.UnMarshalCredential(msg) @@ -1398,6 +1398,8 @@ func (e *Engine) receiveSignalEvents() { go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) case sProto.Body_MODE: + case sProto.Body_GO_IDLE: + e.connMgr.DeactivatePeer(conn) } return nil diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index d9c9881da..f4ed8f1c0 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -36,7 +36,6 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -97,6 +96,7 @@ type MockWGIface struct { GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy GetNetFunc func() *netstack.Net + LastActivitiesFunc func() map[string]time.Time } func (m *MockWGIface) FullStats() (*configurer.Stats, error) { @@ -187,6 +187,13 @@ func (m *MockWGIface) GetNet() *netstack.Net { return m.GetNetFunc() } +func (m *MockWGIface) LastActivities() map[string]time.Time { + if m.LastActivitiesFunc != nil { + return m.LastActivitiesFunc() + } + return nil +} + func TestMain(m *testing.M) { _ = util.InitLog("debug", "console") code := m.Run() @@ -404,7 +411,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.ctx = ctx engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) - engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher()) + engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) engine.connMgr.Start(ctx) type testCase struct { @@ -793,7 +800,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { engine.routeManager = mockRouteManager engine.dnsServer = &dns.MockServer{} - engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher()) + engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface) engine.connMgr.Start(ctx) defer func() { @@ -991,7 +998,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { } engine.dnsServer = mockDNSServer - engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher()) + engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface) engine.connMgr.Start(ctx) defer func() { diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 999472411..38fb3561e 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -38,4 +38,5 @@ type wgIfaceBase interface { GetStats() (map[string]configurer.WGStats, error) GetNet() *netstack.Net FullStats() (*configurer.Stats, error) + LastActivities() map[string]time.Time } diff --git a/client/internal/lazyconn/activity/listener.go b/client/internal/lazyconn/activity/listener.go index 1ef48416a..81b5da17b 100644 --- a/client/internal/lazyconn/activity/listener.go +++ b/client/internal/lazyconn/activity/listener.go @@ -13,7 +13,7 @@ import ( // Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking type Listener struct { - wgIface lazyconn.WGIface + wgIface WgInterface peerCfg lazyconn.PeerConfig conn *net.UDPConn endpoint *net.UDPAddr @@ -22,7 +22,7 @@ type Listener struct { isClosed atomic.Bool // use to avoid error log when closing the listener } -func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) { +func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) { d := &Listener{ wgIface: wgIface, peerCfg: cfg, diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index e18b96465..915fb9cb8 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -1,18 +1,27 @@ package activity import ( + "net" + "net/netip" "sync" + "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" ) +type WgInterface interface { + RemovePeer(peerKey string) error + UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error +} + type Manager struct { OnActivityChan chan peerid.ConnID - wgIface lazyconn.WGIface + wgIface WgInterface peers map[peerid.ConnID]*Listener done chan struct{} @@ -20,7 +29,7 @@ type Manager struct { mu sync.Mutex } -func NewManager(wgIface lazyconn.WGIface) *Manager { +func NewManager(wgIface WgInterface) *Manager { m := &Manager{ OnActivityChan: make(chan peerid.ConnID, 1), wgIface: wgIface, diff --git a/client/internal/lazyconn/inactivity/inactivity.go b/client/internal/lazyconn/inactivity/inactivity.go deleted file mode 100644 index 9b7c8511b..000000000 --- a/client/internal/lazyconn/inactivity/inactivity.go +++ /dev/null @@ -1,75 +0,0 @@ -package inactivity - -import ( - "context" - "time" - - peer "github.com/netbirdio/netbird/client/internal/peer/id" -) - -const ( - DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity - MinimumInactivityThreshold = 3 * time.Minute -) - -type Monitor struct { - id peer.ConnID - timer *time.Timer - cancel context.CancelFunc - inactivityThreshold time.Duration -} - -func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor { - i := &Monitor{ - id: peerID, - timer: time.NewTimer(0), - inactivityThreshold: threshold, - } - i.timer.Stop() - return i -} - -func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) { - i.timer.Reset(i.inactivityThreshold) - defer i.timer.Stop() - - ctx, i.cancel = context.WithCancel(ctx) - defer func() { - defer i.cancel() - select { - case <-i.timer.C: - default: - } - }() - - select { - case <-i.timer.C: - select { - case timeoutChan <- i.id: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } -} - -func (i *Monitor) Stop() { - if i.cancel == nil { - return - } - i.cancel() -} - -func (i *Monitor) PauseTimer() { - i.timer.Stop() -} - -func (i *Monitor) ResetTimer() { - i.timer.Reset(i.inactivityThreshold) -} - -func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) { - i.Stop() - go i.Start(ctx, timeoutChan) -} diff --git a/client/internal/lazyconn/inactivity/inactivity_test.go b/client/internal/lazyconn/inactivity/inactivity_test.go deleted file mode 100644 index 944512985..000000000 --- a/client/internal/lazyconn/inactivity/inactivity_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package inactivity - -import ( - "context" - "testing" - "time" - - peerid "github.com/netbirdio/netbird/client/internal/peer/id" -) - -type MocPeer struct { -} - -func (m *MocPeer) ConnID() peerid.ConnID { - return peerid.ConnID(m) -} - -func TestInactivityMonitor(t *testing.T) { - tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5) - defer testTimeoutCancel() - - p := &MocPeer{} - im := NewInactivityMonitor(p.ConnID(), time.Second*2) - - timeoutChan := make(chan peerid.ConnID) - - exitChan := make(chan struct{}) - - go func() { - defer close(exitChan) - im.Start(tCtx, timeoutChan) - }() - - select { - case <-timeoutChan: - case <-tCtx.Done(): - t.Fatal("timeout") - } - - select { - case <-exitChan: - case <-tCtx.Done(): - t.Fatal("timeout") - } -} - -func TestReuseInactivityMonitor(t *testing.T) { - p := &MocPeer{} - im := NewInactivityMonitor(p.ConnID(), time.Second*2) - - timeoutChan := make(chan peerid.ConnID) - - for i := 2; i > 0; i-- { - exitChan := make(chan struct{}) - - testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5) - - go func() { - defer close(exitChan) - im.Start(testTimeoutCtx, timeoutChan) - }() - - select { - case <-timeoutChan: - case <-testTimeoutCtx.Done(): - t.Fatal("timeout") - } - - select { - case <-exitChan: - case <-testTimeoutCtx.Done(): - t.Fatal("timeout") - } - testTimeoutCancel() - } -} - -func TestStopInactivityMonitor(t *testing.T) { - tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5) - defer testTimeoutCancel() - - p := &MocPeer{} - im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold) - - timeoutChan := make(chan peerid.ConnID) - - exitChan := make(chan struct{}) - - go func() { - defer close(exitChan) - im.Start(tCtx, timeoutChan) - }() - - go func() { - time.Sleep(3 * time.Second) - im.Stop() - }() - - select { - case <-timeoutChan: - t.Fatal("unexpected timeout") - case <-exitChan: - case <-tCtx.Done(): - t.Fatal("timeout") - } -} - -func TestPauseInactivityMonitor(t *testing.T) { - tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10) - defer testTimeoutCancel() - - p := &MocPeer{} - trashHold := time.Second * 3 - im := NewInactivityMonitor(p.ConnID(), trashHold) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - timeoutChan := make(chan peerid.ConnID) - - exitChan := make(chan struct{}) - - go func() { - defer close(exitChan) - im.Start(ctx, timeoutChan) - }() - - time.Sleep(1 * time.Second) // grant time to start the monitor - im.PauseTimer() - - // check to do not receive timeout - thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second) - defer thresholdCancel() - select { - case <-exitChan: - t.Fatal("unexpected exit") - case <-timeoutChan: - t.Fatal("unexpected timeout") - case <-thresholdCtx.Done(): - // test ok - case <-tCtx.Done(): - t.Fatal("test timed out") - } - - // test reset timer - im.ResetTimer() - - select { - case <-tCtx.Done(): - t.Fatal("test timed out") - case <-exitChan: - t.Fatal("unexpected exit") - case <-timeoutChan: - // expected timeout - } -} diff --git a/client/internal/lazyconn/inactivity/manager.go b/client/internal/lazyconn/inactivity/manager.go new file mode 100644 index 000000000..854951729 --- /dev/null +++ b/client/internal/lazyconn/inactivity/manager.go @@ -0,0 +1,152 @@ +package inactivity + +import ( + "context" + "fmt" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +const ( + checkInterval = 1 * time.Minute + + DefaultInactivityThreshold = 15 * time.Minute + MinimumInactivityThreshold = 1 * time.Minute +) + +type WgInterface interface { + LastActivities() map[string]time.Time +} + +type Manager struct { + inactivePeersChan chan map[string]struct{} + + iface WgInterface + interestedPeers map[string]*lazyconn.PeerConfig + inactivityThreshold time.Duration +} + +func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager { + inactivityThreshold, err := validateInactivityThreshold(configuredThreshold) + if err != nil { + inactivityThreshold = DefaultInactivityThreshold + log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold) + } + + log.Infof("inactivity threshold configured: %v", inactivityThreshold) + return &Manager{ + inactivePeersChan: make(chan map[string]struct{}, 1), + iface: iface, + interestedPeers: make(map[string]*lazyconn.PeerConfig), + inactivityThreshold: inactivityThreshold, + } +} + +func (m *Manager) InactivePeersChan() chan map[string]struct{} { + if m == nil { + // return a nil channel that blocks forever + return nil + } + + return m.inactivePeersChan +} + +func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) { + if m == nil { + return + } + + if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists { + return + } + + peerCfg.Log.Infof("adding peer to inactivity manager") + m.interestedPeers[peerCfg.PublicKey] = peerCfg +} + +func (m *Manager) RemovePeer(peer string) { + if m == nil { + return + } + + pi, ok := m.interestedPeers[peer] + if !ok { + return + } + + pi.Log.Debugf("remove peer from inactivity manager") + delete(m.interestedPeers, peer) +} + +func (m *Manager) Start(ctx context.Context) { + if m == nil { + return + } + + ticker := newTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C(): + idlePeers, err := m.checkStats() + if err != nil { + log.Errorf("error checking stats: %v", err) + return + } + + if len(idlePeers) == 0 { + continue + } + + m.notifyInactivePeers(ctx, idlePeers) + } + } +} + +func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) { + select { + case m.inactivePeersChan <- inactivePeers: + case <-ctx.Done(): + return + default: + return + } +} + +func (m *Manager) checkStats() (map[string]struct{}, error) { + lastActivities := m.iface.LastActivities() + + idlePeers := make(map[string]struct{}) + + for peerID, peerCfg := range m.interestedPeers { + lastActive, ok := lastActivities[peerID] + if !ok { + // when peer is in connecting state + peerCfg.Log.Warnf("peer not found in wg stats") + continue + } + + if time.Since(lastActive) > m.inactivityThreshold { + peerCfg.Log.Infof("peer is inactive since: %v", lastActive) + idlePeers[peerID] = struct{}{} + } + } + + return idlePeers, nil +} + +func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) { + if configuredThreshold == nil { + return DefaultInactivityThreshold, nil + } + if *configuredThreshold < MinimumInactivityThreshold { + return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold) + } + return *configuredThreshold, nil +} diff --git a/client/internal/lazyconn/inactivity/manager_test.go b/client/internal/lazyconn/inactivity/manager_test.go new file mode 100644 index 000000000..d012b41a2 --- /dev/null +++ b/client/internal/lazyconn/inactivity/manager_test.go @@ -0,0 +1,113 @@ +package inactivity + +import ( + "context" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/client/internal/lazyconn" +) + +type mockWgInterface struct { + lastActivities map[string]time.Time +} + +func (m *mockWgInterface) LastActivities() map[string]time.Time { + return m.lastActivities +} + +func TestPeerTriggersInactivity(t *testing.T) { + peerID := "peer1" + + wgMock := &mockWgInterface{ + lastActivities: map[string]time.Time{ + peerID: time.Now().Add(-20 * time.Minute), + }, + } + + fakeTick := make(chan time.Time, 1) + newTicker = func(d time.Duration) Ticker { + return &fakeTickerMock{CChan: fakeTick} + } + + peerLog := log.WithField("peer", peerID) + peerCfg := &lazyconn.PeerConfig{ + PublicKey: peerID, + Log: peerLog, + } + + manager := NewManager(wgMock, nil) + manager.AddPeer(peerCfg) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start the manager in a goroutine + go manager.Start(ctx) + + // Send a tick to simulate time passage + fakeTick <- time.Now() + + // Check if peer appears on inactivePeersChan + select { + case inactivePeers := <-manager.inactivePeersChan: + assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive") + case <-time.After(1 * time.Second): + t.Fatal("expected inactivity event, but none received") + } +} + +func TestPeerTriggersActivity(t *testing.T) { + peerID := "peer1" + + wgMock := &mockWgInterface{ + lastActivities: map[string]time.Time{ + peerID: time.Now().Add(-5 * time.Minute), + }, + } + + fakeTick := make(chan time.Time, 1) + newTicker = func(d time.Duration) Ticker { + return &fakeTickerMock{CChan: fakeTick} + } + + peerLog := log.WithField("peer", peerID) + peerCfg := &lazyconn.PeerConfig{ + PublicKey: peerID, + Log: peerLog, + } + + manager := NewManager(wgMock, nil) + manager.AddPeer(peerCfg) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start the manager in a goroutine + go manager.Start(ctx) + + // Send a tick to simulate time passage + fakeTick <- time.Now() + + // Check if peer appears on inactivePeersChan + select { + case <-manager.inactivePeersChan: + t.Fatal("expected inactive peer to be marked inactive") + case <-time.After(1 * time.Second): + // No inactivity event should be received + } +} + +// fakeTickerMock implements Ticker interface for testing +type fakeTickerMock struct { + CChan chan time.Time +} + +func (f *fakeTickerMock) C() <-chan time.Time { + return f.CChan +} + +func (f *fakeTickerMock) Stop() {} diff --git a/client/internal/lazyconn/inactivity/ticker.go b/client/internal/lazyconn/inactivity/ticker.go new file mode 100644 index 000000000..12b64bd5f --- /dev/null +++ b/client/internal/lazyconn/inactivity/ticker.go @@ -0,0 +1,24 @@ +package inactivity + +import "time" + +var newTicker = func(d time.Duration) Ticker { + return &realTicker{t: time.NewTicker(d)} +} + +type Ticker interface { + C() <-chan time.Time + Stop() +} + +type realTicker struct { + t *time.Ticker +} + +func (r *realTicker) C() <-chan time.Time { + return r.t.C +} + +func (r *realTicker) Stop() { + r.t.Stop() +} diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index 74ede50a7..b45b39221 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -52,51 +52,39 @@ type Manager struct { excludes map[string]lazyconn.PeerConfig managedPeersMu sync.Mutex - activityManager *activity.Manager - inactivityMonitors map[peerid.ConnID]*inactivity.Monitor + activityManager *activity.Manager + inactivityManager *inactivity.Manager // Route HA group management + // If any peer in the same HA group is active, all peers in that group should prevent going idle peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group routesMu sync.RWMutex - - onInactive chan peerid.ConnID } // NewManager creates a new lazy connection manager // engineCtx is the context for creating peer Connection -func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager { +func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager { log.Infof("setup lazy connection service") + m := &Manager{ engineCtx: engineCtx, peerStore: peerStore, - connStateDispatcher: connStateDispatcher, inactivityThreshold: inactivity.DefaultInactivityThreshold, managedPeers: make(map[string]*lazyconn.PeerConfig), managedPeersByConnID: make(map[peerid.ConnID]*managedPeer), excludes: make(map[string]lazyconn.PeerConfig), activityManager: activity.NewManager(wgIface), - inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor), peerToHAGroups: make(map[string][]route.HAUniqueID), haGroupToPeers: make(map[route.HAUniqueID][]string), - onInactive: make(chan peerid.ConnID), } - if config.InactivityThreshold != nil { - if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold { - m.inactivityThreshold = *config.InactivityThreshold - } else { - log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold) - } + if wgIface.IsUserspaceBind() { + m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold) + } else { + log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection") } - m.connStateListener = &dispatcher.ConnectionListener{ - OnConnected: m.onPeerConnected, - OnDisconnected: m.onPeerDisconnected, - } - - connStateDispatcher.AddListener(m.connStateListener) - return m } @@ -131,24 +119,28 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) { } } - log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", - len(m.haGroupToPeers), len(m.peerToHAGroups)) + log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups)) } // Start starts the manager and listens for peer activity and inactivity events func (m *Manager) Start(ctx context.Context) { defer m.close() + if m.inactivityManager != nil { + go m.inactivityManager.Start(ctx) + } + for { select { case <-ctx.Done(): return case peerConnID := <-m.activityManager.OnActivityChan: - m.onPeerActivity(ctx, peerConnID) - case peerConnID := <-m.onInactive: - m.onPeerInactivityTimedOut(ctx, peerConnID) + m.onPeerActivity(peerConnID) + case peerIDs := <-m.inactivityManager.InactivePeersChan(): + m.onPeerInactivityTimedOut(peerIDs) } } + } // ExcludePeer marks peers for a permanent connection @@ -156,7 +148,7 @@ func (m *Manager) Start(ctx context.Context) { // Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In // this case, we suppose that the connection status is connected or connecting. // If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function -func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string { +func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -187,7 +179,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo peerCfg.Log.Infof("peer removed from lazy connection exclude list") - if err := m.addActivePeer(ctx, peerCfg); err != nil { + if err := m.addActivePeer(&peerCfg); err != nil { log.Errorf("failed to add peer to lazy connection manager: %s", err) continue } @@ -197,7 +189,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo return added } -func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) { +func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -217,9 +209,6 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo return false, err } - im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold) - m.inactivityMonitors[peerCfg.PeerConnID] = im - m.managedPeers[peerCfg.PublicKey] = &peerCfg m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{ peerCfg: &peerCfg, @@ -229,7 +218,7 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo // Check if this peer should be activated because its HA group peers are active if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok { peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group) - m.activateNewPeerInActiveGroup(ctx, peerCfg) + m.activateNewPeerInActiveGroup(peerCfg) } return false, nil @@ -237,7 +226,7 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo // AddActivePeers adds a list of peers to the lazy connection manager // suppose these peers was in connected or in connecting states -func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error { +func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -247,7 +236,7 @@ func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerCon continue } - if err := m.addActivePeer(ctx, cfg); err != nil { + if err := m.addActivePeer(&cfg); err != nil { cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err) return err } @@ -264,7 +253,7 @@ func (m *Manager) RemovePeer(peerID string) { // ActivatePeer activates a peer connection when a signal message is received // Also activates all peers in the same HA groups as this peer -func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) { +func (m *Manager) ActivatePeer(peerID string) (found bool) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() cfg, mp := m.getPeerForActivation(peerID) @@ -272,15 +261,42 @@ func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) return false } - if !m.activateSinglePeer(ctx, cfg, mp) { + if !m.activateSinglePeer(cfg, mp) { return false } - m.activateHAGroupPeers(ctx, peerID) + m.activateHAGroupPeers(cfg) return true } +func (m *Manager) DeactivatePeer(peerID peerid.ConnID) { + m.managedPeersMu.Lock() + defer m.managedPeersMu.Unlock() + + mp, ok := m.managedPeersByConnID[peerID] + if !ok { + return + } + + if mp.expectedWatcher != watcherInactivity { + return + } + + m.peerStore.PeerConnClose(mp.peerCfg.PublicKey) + + mp.peerCfg.Log.Infof("start activity monitor") + + mp.expectedWatcher = watcherActivity + + m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey) + + if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil { + mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err) + return + } +} + // getPeerForActivation checks if a peer can be activated and returns the necessary structs // Returns nil values if the peer should be skipped func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) { @@ -302,41 +318,36 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma return cfg, mp } -// activateSinglePeer activates a single peer (internal method) -func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool { - mp.expectedWatcher = watcherInactivity - - m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID) - - im, ok := m.inactivityMonitors[cfg.PeerConnID] - if !ok { - cfg.Log.Errorf("inactivity monitor not found for peer") +// activateSinglePeer activates a single peer +// return true if the peer was activated, false if it was already active +func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool { + if mp.expectedWatcher == watcherInactivity { return false } - cfg.Log.Infof("starting inactivity monitor") - go im.Start(ctx, m.onInactive) - + mp.expectedWatcher = watcherInactivity + m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID) + m.inactivityManager.AddPeer(cfg) return true } // activateHAGroupPeers activates all peers in HA groups that the given peer belongs to -func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { +func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) { var peersToActivate []string m.routesMu.RLock() - haGroups := m.peerToHAGroups[triggerPeerID] + haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey] if len(haGroups) == 0 { m.routesMu.RUnlock() - log.Debugf("peer %s is not part of any HA groups", triggerPeerID) + triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups") return } for _, haGroup := range haGroups { peers := m.haGroupToPeers[haGroup] for _, peerID := range peers { - if peerID != triggerPeerID { + if peerID != triggeredPeerCfg.PublicKey { peersToActivate = append(peersToActivate, peerID) } } @@ -350,16 +361,16 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string continue } - if m.activateSinglePeer(ctx, cfg, mp) { + if m.activateSinglePeer(cfg, mp) { activatedCount++ - cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID) + cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) } } if activatedCount > 0 { log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)", - activatedCount, triggerPeerID, haGroups) + activatedCount, triggeredPeerCfg.PublicKey, haGroups) } } @@ -394,13 +405,13 @@ func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) } // activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group -func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) { +func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) { mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID] if !ok { return } - if !m.activateSinglePeer(ctx, &peerCfg, mp) { + if !m.activateSinglePeer(&peerCfg, mp) { return } @@ -408,23 +419,19 @@ func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazy m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey) } -func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { +func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { peerCfg.Log.Warnf("peer already managed") return nil } - im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold) - m.inactivityMonitors[peerCfg.PeerConnID] = im - - m.managedPeers[peerCfg.PublicKey] = &peerCfg + m.managedPeers[peerCfg.PublicKey] = peerCfg m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{ - peerCfg: &peerCfg, + peerCfg: peerCfg, expectedWatcher: watcherInactivity, } - peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list") - go im.Start(ctx, m.onInactive) + m.inactivityManager.AddPeer(peerCfg) return nil } @@ -436,12 +443,7 @@ func (m *Manager) removePeer(peerID string) { cfg.Log.Infof("removing lazy peer") - if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok { - im.Stop() - delete(m.inactivityMonitors, cfg.PeerConnID) - cfg.Log.Debugf("inactivity monitor stopped") - } - + m.inactivityManager.RemovePeer(cfg.PublicKey) m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID) delete(m.managedPeers, peerID) delete(m.managedPeersByConnID, cfg.PeerConnID) @@ -453,10 +455,7 @@ func (m *Manager) close() { m.connStateDispatcher.RemoveListener(m.connStateListener) m.activityManager.Close() - for _, iw := range m.inactivityMonitors { - iw.Stop() - } - m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor) + m.managedPeers = make(map[string]*lazyconn.PeerConfig) m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer) @@ -470,7 +469,7 @@ func (m *Manager) close() { } // shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements -func (m *Manager) shouldDeferIdleForHA(peerID string) bool { +func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool { m.routesMu.RLock() defer m.routesMu.RUnlock() @@ -480,38 +479,45 @@ func (m *Manager) shouldDeferIdleForHA(peerID string) bool { } for _, haGroup := range haGroups { - groupPeers := m.haGroupToPeers[haGroup] - - for _, groupPeerID := range groupPeers { - if groupPeerID == peerID { - continue - } - - cfg, ok := m.managedPeers[groupPeerID] - if !ok { - continue - } - - groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID] - if !ok { - continue - } - - if groupMp.expectedWatcher != watcherInactivity { - continue - } - - // Other member is still connected, defer idle - if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() { - return true - } + if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active { + return true } } return false } -func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { +func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool { + groupPeers := m.haGroupToPeers[haGroup] + for _, groupPeerID := range groupPeers { + + if groupPeerID == peerID { + continue + } + + cfg, ok := m.managedPeers[groupPeerID] + if !ok { + continue + } + + groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID] + if !ok { + continue + } + + if groupMp.expectedWatcher != watcherInactivity { + continue + } + + // If any peer in the group is active, do defer idle + if _, isInactive := inactivePeers[groupPeerID]; !isInactive { + return true + } + } + return false +} + +func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() @@ -528,100 +534,56 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) mp.peerCfg.Log.Infof("detected peer activity") - if !m.activateSinglePeer(ctx, mp.peerCfg, mp) { + if !m.activateSinglePeer(mp.peerCfg, mp) { return } - m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey) + m.activateHAGroupPeers(mp.peerCfg) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey) } -func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) { +func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() - mp, ok := m.managedPeersByConnID[peerConnID] - if !ok { - log.Errorf("peer not found by id: %v", peerConnID) - return - } - - if mp.expectedWatcher != watcherInactivity { - mp.peerCfg.Log.Warnf("ignore inactivity event") - return - } - - if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) { - iw, ok := m.inactivityMonitors[peerConnID] - if ok { - mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements") - iw.ResetMonitor(ctx, m.onInactive) - } else { - mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset") + for peerID := range peerIDs { + peerCfg, ok := m.managedPeers[peerID] + if !ok { + log.Errorf("peer not found by peerId: %v", peerID) + continue } - return - } - mp.peerCfg.Log.Infof("connection timed out") + mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID] + if !ok { + log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID) + continue + } - // this is blocking operation, potentially can be optimized - m.peerStore.PeerConnClose(mp.peerCfg.PublicKey) + if mp.expectedWatcher != watcherInactivity { + mp.peerCfg.Log.Warnf("ignore inactivity event") + continue + } - mp.peerCfg.Log.Infof("start activity monitor") + if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) { + mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers") + continue + } - mp.expectedWatcher = watcherActivity + mp.peerCfg.Log.Infof("connection timed out") - // just in case free up - m.inactivityMonitors[peerConnID].PauseTimer() + // this is blocking operation, potentially can be optimized + m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey) - if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil { - mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err) - return + mp.peerCfg.Log.Infof("start activity monitor") + + mp.expectedWatcher = watcherActivity + + m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey) + + if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil { + mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err) + continue + } } } - -func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) { - m.managedPeersMu.Lock() - defer m.managedPeersMu.Unlock() - - mp, ok := m.managedPeersByConnID[peerConnID] - if !ok { - return - } - - if mp.expectedWatcher != watcherInactivity { - return - } - - iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] - if !ok { - mp.peerCfg.Log.Warnf("inactivity monitor not found for peer") - return - } - - mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected") - iw.PauseTimer() -} - -func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) { - m.managedPeersMu.Lock() - defer m.managedPeersMu.Unlock() - - mp, ok := m.managedPeersByConnID[peerConnID] - if !ok { - return - } - - if mp.expectedWatcher != watcherInactivity { - return - } - - iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] - if !ok { - return - } - - mp.peerCfg.Log.Infof("reset inactivity monitor timer") - iw.ResetTimer() -} diff --git a/client/internal/lazyconn/wgiface.go b/client/internal/lazyconn/wgiface.go index 090a9319c..d55ff9670 100644 --- a/client/internal/lazyconn/wgiface.go +++ b/client/internal/lazyconn/wgiface.go @@ -11,4 +11,6 @@ import ( type WGIface interface { RemovePeer(peerKey string) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + IsUserspaceBind() bool + LastActivities() map[string]time.Time } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c3f44cc7f..1f0ec164e 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -117,10 +117,9 @@ type Conn struct { wgProxyRelay wgproxy.Proxy handshaker *Handshaker - guard *guard.Guard - semaphore *semaphoregroup.SemaphoreGroup - peerConnDispatcher *dispatcher.ConnectionDispatcher - wg sync.WaitGroup + guard *guard.Guard + semaphore *semaphoregroup.SemaphoreGroup + wg sync.WaitGroup // debug purpose dumpState *stateDump @@ -136,18 +135,17 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - peerConnDispatcher: services.PeerConnDispatcher, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), } return conn, nil @@ -226,7 +224,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { } // Close closes this peer Conn issuing a close event to the Conn closeCh -func (conn *Conn) Close() { +func (conn *Conn) Close(signalToRemote bool) { conn.mu.Lock() defer conn.wgWatcherWg.Wait() defer conn.mu.Unlock() @@ -236,6 +234,12 @@ func (conn *Conn) Close() { return } + if signalToRemote { + if err := conn.signaler.SignalIdle(conn.config.Key); err != nil { + conn.Log.Errorf("failed to signal idle state to peer: %v", err) + } + } + conn.Log.Infof("close peer connection") conn.ctxCancel() @@ -404,15 +408,10 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn } wgConfigWorkaround() - oldState := conn.currentConnPriority conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) - - if oldState == conntype.None { - conn.peerConnDispatcher.NotifyConnected(conn.ConnID()) - } } func (conn *Conn) onICEStateDisconnected() { @@ -450,7 +449,6 @@ func (conn *Conn) onICEStateDisconnected() { } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.currentConnPriority = conntype.None - conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID()) } changed := conn.statusICE.Get() != worker.StatusDisconnected @@ -530,7 +528,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.Log.Infof("start to communicate with peer via relay") conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) - conn.peerConnDispatcher.NotifyConnected(conn.ConnID()) } func (conn *Conn) onRelayDisconnected() { @@ -545,11 +542,7 @@ func (conn *Conn) onRelayDisconnected() { if conn.currentConnPriority == conntype.Relay { conn.Log.Debugf("clean up WireGuard config") - if err := conn.removeWgPeer(); err != nil { - conn.Log.Errorf("failed to remove wg endpoint: %v", err) - } conn.currentConnPriority = conntype.None - conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID()) } if conn.wgProxyRelay != nil { diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index 713123e5d..9022e0299 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -68,3 +68,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, return nil } + +func (s *Signaler) SignalIdle(remoteKey string) error { + return s.signal.Send(&sProto.Message{ + Key: s.wgPrivateKey.PublicKey().String(), + RemoteKey: remoteKey, + Body: &sProto.Body{ + Type: sProto.Body_GO_IDLE, + }, + }) +} diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go index 81ac7a5b6..099fe4528 100644 --- a/client/internal/peerstore/store.go +++ b/client/internal/peerstore/store.go @@ -95,6 +95,17 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) { } +func (s *Store) PeerConnIdle(pubKey string) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return + } + p.Close(true) +} + func (s *Store) PeerConnClose(pubKey string) { s.peerConnsMu.RLock() defer s.peerConnsMu.RUnlock() @@ -103,7 +114,7 @@ func (s *Store) PeerConnClose(pubKey string) { if !ok { return } - p.Close() + p.Close(false) } func (s *Store) PeersPubKey() []string { diff --git a/monotime/time.go b/monotime/time.go new file mode 100644 index 000000000..6032fb60b --- /dev/null +++ b/monotime/time.go @@ -0,0 +1,29 @@ +package monotime + +import ( + "time" +) + +var ( + baseWallTime time.Time + baseWallNano int64 +) + +func init() { + baseWallTime = time.Now() + baseWallNano = baseWallTime.UnixNano() +} + +// Now returns the current time as Unix nanoseconds (int64). +// It uses monotonic time measurement from the base time to ensure +// the returned value increases monotonically and is not affected +// by system clock adjustments. +// +// Performance optimization: By capturing the base wall time once at startup +// and using time.Since() for elapsed calculation, this avoids repeated +// time.Now() calls and leverages Go's internal monotonic clock for +// efficient duration measurement. +func Now() int64 { + elapsed := time.Since(baseWallTime) + return baseWallNano + int64(elapsed) +} diff --git a/monotime/time_test.go b/monotime/time_test.go new file mode 100644 index 000000000..ac837b226 --- /dev/null +++ b/monotime/time_test.go @@ -0,0 +1,20 @@ +package monotime + +import ( + "testing" + "time" +) + +func BenchmarkMonotimeNow(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = Now() + } +} + +func BenchmarkTimeNow(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = time.Now() + } +} diff --git a/signal/proto/signalexchange.pb.go b/signal/proto/signalexchange.pb.go index 30f704c6f..3d45dea69 100644 --- a/signal/proto/signalexchange.pb.go +++ b/signal/proto/signalexchange.pb.go @@ -29,6 +29,7 @@ const ( Body_ANSWER Body_Type = 1 Body_CANDIDATE Body_Type = 2 Body_MODE Body_Type = 4 + Body_GO_IDLE Body_Type = 5 ) // Enum value maps for Body_Type. @@ -38,12 +39,14 @@ var ( 1: "ANSWER", 2: "CANDIDATE", 4: "MODE", + 5: "GO_IDLE", } Body_Type_value = map[string]int32{ "OFFER": 0, "ANSWER": 1, "CANDIDATE": 2, "MODE": 4, + "GO_IDLE": 5, } ) @@ -225,7 +228,7 @@ type Body struct { FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"` // RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` - // relayServerAddress is an IP:port of the relay server + // relayServerAddress is url of the relay server RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` } @@ -440,7 +443,7 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, - 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa6, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, + 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, @@ -463,33 +466,34 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, - 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e, - 0x0a, 0x04, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x88, 0x01, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, - 0x0a, 0x0f, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, - 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, - 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, - 0x64, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, - 0x0a, 0x0e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, - 0x12, 0x4c, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, - 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, - 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, - 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, - 0x0a, 0x0d, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, - 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, - 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, - 0x61, 0x67, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b, + 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d, + 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, + 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, + 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, + 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, + 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, + 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, + 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, + 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/signal/proto/signalexchange.proto b/signal/proto/signalexchange.proto index 4431edd7c..b04d6ef28 100644 --- a/signal/proto/signalexchange.proto +++ b/signal/proto/signalexchange.proto @@ -47,6 +47,7 @@ message Body { ANSWER = 1; CANDIDATE = 2; MODE = 4; + GO_IDLE = 5; } Type type = 1; string payload = 2; @@ -74,4 +75,4 @@ message RosenpassConfig { bytes rosenpassPubKey = 1; // rosenpassServerAddr is an IP:port of the rosenpass service string rosenpassServerAddr = 2; -} \ No newline at end of file +} From 8942c40fde62421d510361a2558afafc1bcd3fd9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sun, 6 Jul 2025 15:13:14 +0200 Subject: [PATCH 02/12] [client] Fix nil pointer exception in lazy connection (#4109) Remove unused variable --- client/internal/lazyconn/manager/manager.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/client/internal/lazyconn/manager/manager.go b/client/internal/lazyconn/manager/manager.go index b45b39221..416e4e7e7 100644 --- a/client/internal/lazyconn/manager/manager.go +++ b/client/internal/lazyconn/manager/manager.go @@ -11,7 +11,6 @@ import ( "github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn/activity" "github.com/netbirdio/netbird/client/internal/lazyconn/inactivity" - "github.com/netbirdio/netbird/client/internal/peer/dispatcher" peerid "github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/route" @@ -43,10 +42,8 @@ type Config struct { type Manager struct { engineCtx context.Context peerStore *peerstore.Store - connStateDispatcher *dispatcher.ConnectionDispatcher inactivityThreshold time.Duration - connStateListener *dispatcher.ConnectionListener managedPeers map[string]*lazyconn.PeerConfig managedPeersByConnID map[peerid.ConnID]*managedPeer excludes map[string]lazyconn.PeerConfig @@ -453,7 +450,6 @@ func (m *Manager) close() { m.managedPeersMu.Lock() defer m.managedPeersMu.Unlock() - m.connStateDispatcher.RemoveListener(m.connStateListener) m.activityManager.Close() m.managedPeers = make(map[string]*lazyconn.PeerConfig) From 768ba24fda97db833b0d44eca6a269ed12cb245f Mon Sep 17 00:00:00 2001 From: "M. Essam" Date: Tue, 8 Jul 2025 19:08:13 +0300 Subject: [PATCH 03/12] [management,rest] Add name/ip filters to peer management rest client (#4112) --- management/client/rest/accounts.go | 6 ++--- management/client/rest/client.go | 10 +++++++- management/client/rest/dns.go | 14 +++++------ management/client/rest/events.go | 2 +- management/client/rest/geo.go | 4 ++-- management/client/rest/groups.go | 10 ++++---- management/client/rest/networks.go | 30 +++++++++++------------ management/client/rest/peers.go | 32 ++++++++++++++++++++----- management/client/rest/peers_test.go | 4 ++++ management/client/rest/policies.go | 10 ++++---- management/client/rest/posturechecks.go | 10 ++++---- management/client/rest/routes.go | 10 ++++---- management/client/rest/setupkeys.go | 10 ++++---- management/client/rest/tokens.go | 8 +++---- management/client/rest/users.go | 12 +++++----- 15 files changed, 102 insertions(+), 70 deletions(-) diff --git a/management/client/rest/accounts.go b/management/client/rest/accounts.go index 2530e4f72..fbe3010e1 100644 --- a/management/client/rest/accounts.go +++ b/management/client/rest/accounts.go @@ -16,7 +16,7 @@ type AccountsAPI struct { // List list all accounts, only returns one account always // See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil, nil) if err != nil { return nil, err } @@ -34,7 +34,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api. if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api. // Delete delete account // See more: https://docs.netbird.io/api/resources/accounts#delete-an-account func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/client.go b/management/client/rest/client.go index 8bf11caae..b5945985f 100644 --- a/management/client/rest/client.go +++ b/management/client/rest/client.go @@ -117,7 +117,7 @@ func (c *Client) initialize() { } // NewRequest creates and executes new management API request -func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader, query map[string]string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body) if err != nil { return nil, err @@ -129,6 +129,14 @@ func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Re req.Header.Add("Content-Type", "application/json") } + if len(query) != 0 { + q := req.URL.Query() + for k, v := range query { + q.Add(k, v) + } + req.URL.RawQuery = q.Encode() + } + resp, err := c.httpClient.Do(req) if err != nil { return nil, err diff --git a/management/client/rest/dns.go b/management/client/rest/dns.go index 1e35c0226..3fb74d5f5 100644 --- a/management/client/rest/dns.go +++ b/management/client/rest/dns.go @@ -16,7 +16,7 @@ type DNSAPI struct { // ListNameserverGroups list all nameserver groups // See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGrou // GetNameserverGroup get nameserver group info // See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil, nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiD if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID st // DeleteNameserverGroup delete nameserver group // See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil, nil) if err != nil { return err } @@ -94,7 +94,7 @@ func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID st // GetSettings get DNS settings // See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil, nil) if err != nil { return nil, err } @@ -112,7 +112,7 @@ func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettin if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } diff --git a/management/client/rest/events.go b/management/client/rest/events.go index cae813e86..775d3ba2e 100644 --- a/management/client/rest/events.go +++ b/management/client/rest/events.go @@ -14,7 +14,7 @@ type EventsAPI struct { // List list all events // See more: https://docs.netbird.io/api/resources/events#list-all-events func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil) if err != nil { return nil, err } diff --git a/management/client/rest/geo.go b/management/client/rest/geo.go index d06d65d80..dfecee09e 100644 --- a/management/client/rest/geo.go +++ b/management/client/rest/geo.go @@ -14,7 +14,7 @@ type GeoLocationAPI struct { // ListCountries list all country codes // See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil, nil) if err != nil { return nil, err } @@ -28,7 +28,7 @@ func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, erro // ListCountryCities Get a list of all English city names for a given country code // See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil, nil) if err != nil { return nil, err } diff --git a/management/client/rest/groups.go b/management/client/rest/groups.go index 7612b7188..7d4bac62c 100644 --- a/management/client/rest/groups.go +++ b/management/client/rest/groups.go @@ -16,7 +16,7 @@ type GroupsAPI struct { // List list all groups // See more: https://docs.netbird.io/api/resources/groups#list-all-groups func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) { // Get get group info // See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil, nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONReq if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutA // Delete delete group // See more: https://docs.netbird.io/api/resources/groups#delete-a-group func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/networks.go b/management/client/rest/networks.go index b744e3fe7..9441780f3 100644 --- a/management/client/rest/networks.go +++ b/management/client/rest/networks.go @@ -16,7 +16,7 @@ type NetworksAPI struct { // List list all networks // See more: https://docs.netbird.io/api/resources/networks#list-all-networks func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) { // Get get network info // See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil, nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSO if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api. if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api. // Delete delete network // See more: https://docs.netbird.io/api/resources/networks#delete-a-network func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil, nil) if err != nil { return err } @@ -108,7 +108,7 @@ func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI { // List list all resources in networks // See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil, nil) if err != nil { return nil, err } @@ -122,7 +122,7 @@ func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, // Get get network resource info // See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil) if err != nil { return nil, err } @@ -140,7 +140,7 @@ func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNet if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -158,7 +158,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -172,7 +172,7 @@ func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID stri // Delete delete network resource // See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil) if err != nil { return err } @@ -200,7 +200,7 @@ func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI { // List list all routers in networks // See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil, nil) if err != nil { return nil, err } @@ -214,7 +214,7 @@ func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, erro // Get get network router info // See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil) if err != nil { return nil, err } @@ -232,7 +232,7 @@ func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetwo if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -250,7 +250,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -264,7 +264,7 @@ func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, // Delete delete network router // See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/peers.go b/management/client/rest/peers.go index 37679fdb9..f4364bb62 100644 --- a/management/client/rest/peers.go +++ b/management/client/rest/peers.go @@ -13,10 +13,30 @@ type PeersAPI struct { c *Client } +// PeersListOption options for Peers List API +type PeersListOption func() (string, string) + +func PeerNameFilter(name string) PeersListOption { + return func() (string, string) { + return "name", name + } +} + +func PeerIPFilter(ip string) PeersListOption { + return func() (string, string) { + return "ip", ip + } +} + // List list all peers // See more: https://docs.netbird.io/api/resources/peers#list-all-peers -func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil) +func (a *PeersAPI) List(ctx context.Context, opts ...PeersListOption) ([]api.Peer, error) { + query := make(map[string]string) + for _, o := range opts { + k, v := o() + query[k] = v + } + resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil, query) if err != nil { return nil, err } @@ -30,7 +50,7 @@ func (a *PeersAPI) List(ctx context.Context) ([]api.Peer, error) { // Get retrieve a peer // See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil, nil) if err != nil { return nil, err } @@ -48,7 +68,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -62,7 +82,7 @@ func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApi // Delete delete a peer // See more: https://docs.netbird.io/api/resources/peers#delete-a-peer func (a *PeersAPI) Delete(ctx context.Context, peerID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil, nil) if err != nil { return err } @@ -76,7 +96,7 @@ func (a *PeersAPI) Delete(ctx context.Context, peerID string) error { // ListAccessiblePeers list all peers that the specified peer can connect to within the network // See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil, nil) if err != nil { return nil, err } diff --git a/management/client/rest/peers_test.go b/management/client/rest/peers_test.go index 4c5cd1e60..f31e44e10 100644 --- a/management/client/rest/peers_test.go +++ b/management/client/rest/peers_test.go @@ -184,6 +184,10 @@ func TestPeers_Integration(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, peers) + filteredPeers, err := c.Peers.List(context.Background(), rest.PeerIPFilter("192.168.10.0")) + require.NoError(t, err) + require.Empty(t, filteredPeers) + peer, err := c.Peers.Get(context.Background(), peers[0].Id) require.NoError(t, err) assert.Equal(t, peers[0].Id, peer.Id) diff --git a/management/client/rest/policies.go b/management/client/rest/policies.go index 2f2df4a78..a6e0e38d3 100644 --- a/management/client/rest/policies.go +++ b/management/client/rest/policies.go @@ -18,7 +18,7 @@ type PoliciesAPI struct { func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) { path := "/api/policies" - resp, err := a.c.NewRequest(ctx, "GET", path, nil) + resp, err := a.c.NewRequest(ctx, "GET", path, nil, nil) if err != nil { return nil, err } @@ -32,7 +32,7 @@ func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) { // Get get policy info // See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil, nil) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSO if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -70,7 +70,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -84,7 +84,7 @@ func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.P // Delete delete policy // See more: https://docs.netbird.io/api/resources/policies#delete-a-policy func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/posturechecks.go b/management/client/rest/posturechecks.go index 622eeeb64..2ab8f4549 100644 --- a/management/client/rest/posturechecks.go +++ b/management/client/rest/posturechecks.go @@ -16,7 +16,7 @@ type PostureChecksAPI struct { // List list all posture checks // See more: https://docs.netbird.io/api/resources/posture-checks#list-all-posture-checks func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) // Get get posture check info // See more: https://docs.netbird.io/api/resources/posture-checks#retrieve-a-posture-check func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api.PostureCheck, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil, nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostur if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, re // Delete delete posture check // See more: https://docs.netbird.io/api/resources/posture-checks#delete-a-posture-check func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/routes.go b/management/client/rest/routes.go index 671c3bfc9..183c363cf 100644 --- a/management/client/rest/routes.go +++ b/management/client/rest/routes.go @@ -16,7 +16,7 @@ type RoutesAPI struct { // List list all routes // See more: https://docs.netbird.io/api/resources/routes#list-all-routes func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) { // Get get route info // See more: https://docs.netbird.io/api/resources/routes#retrieve-a-route func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil, nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONReq if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -80,7 +80,7 @@ func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutA // Delete delete route // See more: https://docs.netbird.io/api/resources/routes#delete-a-route func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/setupkeys.go b/management/client/rest/setupkeys.go index 5625b6acc..6f26cd0b7 100644 --- a/management/client/rest/setupkeys.go +++ b/management/client/rest/setupkeys.go @@ -16,7 +16,7 @@ type SetupKeysAPI struct { // List list all setup keys // See more: https://docs.netbird.io/api/resources/setup-keys#list-all-setup-keys func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) { // Get get setup key info // See more: https://docs.netbird.io/api/resources/setup-keys#retrieve-a-setup-key func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKey, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil, nil) if err != nil { return nil, err } @@ -50,7 +50,7 @@ func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJ if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -68,7 +68,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -82,7 +82,7 @@ func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request ap // Delete delete setup key // See more: https://docs.netbird.io/api/resources/setup-keys#delete-a-setup-key func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/tokens.go b/management/client/rest/tokens.go index 278a0d159..7a63d0c9d 100644 --- a/management/client/rest/tokens.go +++ b/management/client/rest/tokens.go @@ -16,7 +16,7 @@ type TokensAPI struct { // List list user tokens // See more: https://docs.netbird.io/api/resources/tokens#list-all-tokens func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAccessToken, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil, nil) if err != nil { return nil, err } @@ -30,7 +30,7 @@ func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAcce // Get get user token info // See more: https://docs.netbird.io/api/resources/tokens#retrieve-a-token func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.PersonalAccessToken, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -62,7 +62,7 @@ func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostA // Delete delete user token // See more: https://docs.netbird.io/api/resources/tokens#delete-a-token func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil) if err != nil { return err } diff --git a/management/client/rest/users.go b/management/client/rest/users.go index 107b0581e..f0ef54be2 100644 --- a/management/client/rest/users.go +++ b/management/client/rest/users.go @@ -16,7 +16,7 @@ type UsersAPI struct { // List list all users, only returns one user always // See more: https://docs.netbird.io/api/resources/users#list-all-users func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil, nil) if err != nil { return nil, err } @@ -34,7 +34,7 @@ func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONReque if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -52,7 +52,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi if err != nil { return nil, err } - resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes)) + resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes), nil) if err != nil { return nil, err } @@ -66,7 +66,7 @@ func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApi // Delete delete user // See more: https://docs.netbird.io/api/resources/users#delete-a-user func (a *UsersAPI) Delete(ctx context.Context, userID string) error { - resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil) + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil, nil) if err != nil { return err } @@ -80,7 +80,7 @@ func (a *UsersAPI) Delete(ctx context.Context, userID string) error { // ResendInvitation resend user invitation // See more: https://docs.netbird.io/api/resources/users#resend-user-invitation func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error { - resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil) + resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil, nil) if err != nil { return err } @@ -94,7 +94,7 @@ func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error { // Current gets the current user info // See more: https://docs.netbird.io/api/resources/users#retrieve-current-user func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) { - resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil) + resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil, nil) if err != nil { return nil, err } From 969f1ed59a9645eaee4fd76b3529b4b829accc35 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 9 Jul 2025 10:14:10 +0300 Subject: [PATCH 04/12] [management] Remove deleted user peers from groups on user deletion (#4121) Refactors peer deletion to centralize group cleanup logic, ensuring deleted peers are consistently removed from all groups in one place. - Removed redundant group removal code from DefaultAccountManager.DeletePeer - Added group removal logic inside deletePeers to handle both single and multiple peer deletions --- management/server/peer.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 1dd390dd9..44156e534 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -364,19 +364,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peerID) - if err != nil { - return fmt.Errorf("failed to get peer groups: %w", err) - } - - for _, group := range groups { - group.RemovePeer(peerID) - err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) - if err != nil { - return fmt.Errorf("failed to save group: %w", err) - } - } - eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) return err }) @@ -1517,13 +1504,26 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto } dnsDomain := am.GetDNSDomain(settings) + network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + for _, peer := range peers { - if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { - return nil, err + groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthUpdate, accountID, peer.ID) + if err != nil { + return nil, fmt.Errorf("failed to get peer groups: %w", err) } - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) - if err != nil { + for _, group := range groups { + group.RemovePeer(peer.ID) + err = transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + if err != nil { + return nil, fmt.Errorf("failed to save group: %w", err) + } + } + + if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { return nil, err } From f17dd3619cf6c121fa4ea76d3f1e8192b15497d8 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Wed, 9 Jul 2025 15:49:09 +0200 Subject: [PATCH 05/12] [misc] update image in README.md (#4122) --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index c3b365694..d5469c28b 100644 --- a/README.md +++ b/README.md @@ -50,10 +50,9 @@ **Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. -### Open-Source Network Security in a Single Platform +### Open Source Network Security in a Single Platform - -![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) +centralized-network-management 1 ### NetBird on Lawrence Systems (Video) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) From 408f423adcf7a51f1ae56cdb6016c0b0158643fb Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 9 Jul 2025 22:16:08 +0200 Subject: [PATCH 06/12] [client] Disable pidfd check on Android 11 and below (#4127) Disable pidfd check on Android 11 and below On Android 11 (SDK <= 30) and earlier, pidfd-related system calls are blocked by seccomp policies, causing SIGSYS crashes. This change overrides `checkPidfdOnce` to return an error on affected versions, preventing the use of unsupported pidfd features. --- .github/workflows/mobile-build-validation.yml | 2 +- client/android/client.go | 4 ++- client/android/exec.go | 26 +++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 client/android/exec.go diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index 569956a54..c7d43695b 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -43,7 +43,7 @@ jobs: - name: gomobile init run: gomobile init - name: build android netbird lib - run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android + run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android env: CGO_ENABLED: 0 ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 diff --git a/client/android/client.go b/client/android/client.go index a17439696..0d0c76549 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -64,7 +64,9 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { +func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { + execWorkaround(androidSDKVersion) + net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) return &Client{ cfgFile: cfgFile, diff --git a/client/android/exec.go b/client/android/exec.go new file mode 100644 index 000000000..805d3129b --- /dev/null +++ b/client/android/exec.go @@ -0,0 +1,26 @@ +//go:build android + +package android + +import ( + "fmt" + _ "unsafe" +) + +// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520 +// In Android version 11 and earlier, pidfd-related system calls +// are not allowed by the seccomp policy, which causes crashes due +// to SIGSYS signals. + +//go:linkname checkPidfdOnce os.checkPidfdOnce +var checkPidfdOnce func() error + +func execWorkaround(androidSDKVersion int) { + if androidSDKVersion > 30 { // above Android 11 + return + } + + checkPidfdOnce = func() error { + return fmt.Errorf("unsupported Android version") + } +} From e59d75d56ab047c5374177f16976988aac486405 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 10 Jul 2025 14:24:20 +0200 Subject: [PATCH 07/12] Nil check in iface configurer (#4132) --- client/iface/iface.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/client/iface/iface.go b/client/iface/iface.go index 1b9055e6c..e90c3536b 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -29,6 +29,11 @@ const ( WgInterfaceDefault = configurer.WgInterfaceDefault ) +var ( + // ErrIfaceNotFound is returned when the WireGuard interface is not found + ErrIfaceNotFound = fmt.Errorf("wireguard interface not found") +) + type wgProxyFactory interface { GetProxy() wgproxy.Proxy Free() error @@ -117,6 +122,9 @@ func (w *WGIface) UpdateAddr(newAddr string) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) @@ -126,6 +134,9 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv func (w *WGIface) RemovePeer(peerKey string) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) return w.configurer.RemovePeer(peerKey) @@ -135,6 +146,9 @@ func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) return w.configurer.AddAllowedIP(peerKey, allowedIP) @@ -144,6 +158,9 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return ErrIfaceNotFound + } log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) return w.configurer.RemoveAllowedIP(peerKey, allowedIP) @@ -214,6 +231,9 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device { // GetStats returns the last handshake time, rx and tx bytes func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { + if w.configurer == nil { + return nil, ErrIfaceNotFound + } return w.configurer.GetStats() } @@ -221,11 +241,19 @@ func (w *WGIface) LastActivities() map[string]time.Time { w.mu.Lock() defer w.mu.Unlock() + if w.configurer == nil { + return nil + } + return w.configurer.LastActivities() } func (w *WGIface) FullStats() (*configurer.Stats, error) { + if w.configurer == nil { + return nil, ErrIfaceNotFound + } + return w.configurer.FullStats() } From e3b40ba694a5f3b3396b09ebac86b224848901d9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 10 Jul 2025 15:00:58 +0200 Subject: [PATCH 08/12] Update cli description of lazy connection (#4133) --- client/cmd/root.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/cmd/root.go b/client/cmd/root.go index 16e445f4d..e00a9b073 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -184,7 +184,7 @@ func init() { upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") - upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.") + upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) From 8632dd15f13a6954d71f09af667089bb07571a26 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:21:01 +0200 Subject: [PATCH 09/12] [management] added cleanupWindow for collecting several ephemeral peers to delete (#4130) --------- Co-authored-by: Maycon Santos Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com> --- management/client/client_test.go | 6 ++ management/server/account/manager.go | 1 + management/server/dns_test.go | 2 + management/server/ephemeral.go | 37 +++++-- management/server/ephemeral_test.go | 98 ++++++++++++++++--- management/server/management_proto_test.go | 6 +- management/server/mock_server/account_mock.go | 4 + management/server/nameserver_test.go | 6 ++ management/server/peer.go | 27 +++-- management/server/peer_test.go | 10 ++ 10 files changed, 171 insertions(+), 26 deletions(-) diff --git a/management/client/client_test.go b/management/client/client_test.go index c163d1833..1847af73e 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -87,6 +87,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ). Return(&types.Settings{}, nil). AnyTimes() + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() + permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock. EXPECT(). diff --git a/management/server/account/manager.go b/management/server/account/manager.go index ed17fa5ec..f8aa2756a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -112,6 +112,7 @@ type Manager interface { GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) + BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 02bb042d7..31c944a25 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -216,6 +216,8 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + // return empty extra settings for expected calls to UpdateAccountPeers + settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 3cb9b7536..9f4348ebb 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -15,6 +15,8 @@ import ( const ( ephemeralLifeTime = 10 * time.Minute + // cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure. + cleanupWindow = 1 * time.Minute ) var ( @@ -41,6 +43,9 @@ type EphemeralManager struct { tailPeer *ephemeralPeer peersLock sync.Mutex timer *time.Timer + + lifeTime time.Duration + cleanupWindow time.Duration } // NewEphemeralManager instantiate new EphemeralManager @@ -48,6 +53,9 @@ func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *E return &EphemeralManager{ store: store, accountManager: accountManager, + + lifeTime: ephemeralLifeTime, + cleanupWindow: cleanupWindow, } } @@ -60,7 +68,7 @@ func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) { e.loadEphemeralPeers(ctx) if e.headPeer != nil { - e.timer = time.AfterFunc(ephemeralLifeTime, func() { + e.timer = time.AfterFunc(e.lifeTime, func() { e.cleanup(ctx) }) } @@ -113,9 +121,13 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. return } - e.addPeer(peer.AccountID, peer.ID, newDeadLine()) + e.addPeer(peer.AccountID, peer.ID, e.newDeadLine()) if e.timer == nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow + if delay < 0 { + delay = 0 + } + e.timer = time.AfterFunc(delay, func() { e.cleanup(ctx) }) } @@ -128,7 +140,7 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { return } - t := newDeadLine() + t := e.newDeadLine() for _, p := range peers { e.addPeer(p.AccountID, p.ID, t) } @@ -155,7 +167,11 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } if e.headPeer != nil { - e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { + delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow + if delay < 0 { + delay = 0 + } + e.timer = time.AfterFunc(delay, func() { e.cleanup(ctx) }) } else { @@ -164,13 +180,20 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() + bufferAccountCall := make(map[string]struct{}) + for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) + } else { + bufferAccountCall[p.accountID] = struct{}{} } } + for accountID := range bufferAccountCall { + e.accountManager.BufferUpdateAccountPeers(ctx, accountID) + } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { @@ -223,6 +246,6 @@ func (e *EphemeralManager) isPeerOnList(id string) bool { return false } -func newDeadLine() time.Time { - return timeNow().Add(ephemeralLifeTime) +func (e *EphemeralManager) newDeadLine() time.Time { + return timeNow().Add(e.lifeTime) } diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 3cf6ae7f3..f71d48c58 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -3,9 +3,12 @@ package server import ( "context" "fmt" + "sync" "testing" "time" + "github.com/stretchr/testify/assert" + nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -27,28 +30,65 @@ func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStren return peers, nil } -type MocAccountManager struct { +type MockAccountManager struct { + mu sync.Mutex nbAccount.Manager - store *MockStore + store *MockStore + deletePeerCalls int + bufferUpdateCalls map[string]int + wg *sync.WaitGroup } -func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { +func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { + a.mu.Lock() + defer a.mu.Unlock() + a.deletePeerCalls++ + if a.wg != nil { + a.wg.Done() + } delete(a.store.account.Peers, peerID) - return nil //nolint:nil + return nil } -func (a MocAccountManager) GetStore() store.Store { +func (a *MockAccountManager) GetDeletePeerCalls() int { + a.mu.Lock() + defer a.mu.Unlock() + return a.deletePeerCalls +} + +func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + a.mu.Lock() + defer a.mu.Unlock() + if a.bufferUpdateCalls == nil { + a.bufferUpdateCalls = make(map[string]int) + } + a.bufferUpdateCalls[accountID]++ +} + +func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int { + a.mu.Lock() + defer a.mu.Unlock() + if a.bufferUpdateCalls == nil { + return 0 + } + return a.bufferUpdateCalls[accountID] +} + +func (a *MockAccountManager) GetStore() store.Store { return a.store } func TestNewManager(t *testing.T) { + t.Cleanup(func() { + timeNow = time.Now + }) startTime := time.Now() timeNow = func() time.Time { return startTime } store := &MockStore{} - am := MocAccountManager{ + am := MockAccountManager{ store: store, } @@ -56,7 +96,7 @@ func TestNewManager(t *testing.T) { numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, am) + mgr := NewEphemeralManager(store, &am) mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) @@ -67,13 +107,16 @@ func TestNewManager(t *testing.T) { } func TestNewManagerPeerConnected(t *testing.T) { + t.Cleanup(func() { + timeNow = time.Now + }) startTime := time.Now() timeNow = func() time.Time { return startTime } store := &MockStore{} - am := MocAccountManager{ + am := MockAccountManager{ store: store, } @@ -81,7 +124,7 @@ func TestNewManagerPeerConnected(t *testing.T) { numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, am) + mgr := NewEphemeralManager(store, &am) mgr.loadEphemeralPeers(context.Background()) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) @@ -95,13 +138,16 @@ func TestNewManagerPeerConnected(t *testing.T) { } func TestNewManagerPeerDisconnected(t *testing.T) { + t.Cleanup(func() { + timeNow = time.Now + }) startTime := time.Now() timeNow = func() time.Time { return startTime } store := &MockStore{} - am := MocAccountManager{ + am := MockAccountManager{ store: store, } @@ -109,7 +155,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, am) + mgr := NewEphemeralManager(store, &am) mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { mgr.OnPeerConnected(context.Background(), v) @@ -126,6 +172,36 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } } +func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { + const ( + ephemeralPeers = 10 + testLifeTime = 1 * time.Second + testCleanupWindow = 100 * time.Millisecond + ) + mockStore := &MockStore{} + mockAM := &MockAccountManager{ + store: mockStore, + } + mockAM.wg = &sync.WaitGroup{} + mockAM.wg.Add(ephemeralPeers) + mgr := NewEphemeralManager(mockStore, mockAM) + mgr.lifeTime = testLifeTime + mgr.cleanupWindow = testCleanupWindow + + account := newAccountWithId(context.Background(), "account", "", "", false) + mockStore.account = account + for i := range ephemeralPeers { + p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true} + mockStore.account.Peers[p.ID] = p + time.Sleep(testCleanupWindow / ephemeralPeers) + mgr.OnPeerDisconnected(context.Background(), p) + } + mockAM.wg.Wait() + assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime") + assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once") + assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers") +} + func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { store.account = newAccountWithId(context.Background(), "my account", "", "", false) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 337890ef9..57c00ed9f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -440,7 +440,11 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config) GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). AnyTimes(). Return(&types.Settings{}, nil) - + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(store) accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8837f9f50..4004f1b57 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -126,6 +126,10 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID // do nothing } +func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 75d1e7972..8fada742c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -778,6 +778,12 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() + permissionsManager := permissions.NewManager(store) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer.go b/management/server/peer.go index 44156e534..a60513b38 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -375,7 +375,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers { + if updateAccountPeers && userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -1177,6 +1177,19 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account globalStart := time.Now() + hasPeersConnected := false + for _, peer := range account.Peers { + if am.peersUpdateManager.HasChannel(peer.ID) { + hasPeersConnected = true + break + } + + } + + if !hasPeersConnected { + return + } + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) @@ -1198,6 +1211,12 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } + extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) + return + } + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -1232,12 +1251,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - start = time.Now() update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 31439d670..07ec5037b 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1344,6 +1344,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(s) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) @@ -1556,6 +1561,11 @@ func Test_LoginPeer(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() permissionsManager := permissions.NewManager(s) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) From a7ea881900b53bdafc80ccb7ad42e96eda308ef7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 10 Jul 2025 16:13:53 +0200 Subject: [PATCH 10/12] [client] Add rotated logs flag for debug bundle generation (#4100) --- client/cmd/debug.go | 46 +++++--- client/cmd/root.go | 10 -- client/internal/debug/debug.go | 83 +++++++++----- client/proto/daemon.pb.go | 13 ++- client/proto/daemon.proto | 1 + client/proto/daemon_grpc.pb.go | 192 ++++++++++++++------------------- client/proto/generate.sh | 2 +- client/server/debug.go | 1 + go.mod | 2 +- go.sum | 4 +- util/log.go | 2 +- 11 files changed, 188 insertions(+), 168 deletions(-) diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 385bd95f5..4036bb8f6 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -17,10 +17,18 @@ import ( "github.com/netbirdio/netbird/client/server" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/upload-server/types" ) const errCloseConnection = "Failed to close connection: %v" +var ( + logFileCount uint32 + systemInfoFlag bool + uploadBundleFlag bool + uploadBundleURLFlag string +) + var debugCmd = &cobra.Command{ Use: "debug", Short: "Debugging commands", @@ -88,12 +96,13 @@ func debugBundle(cmd *cobra.Command, _ []string) error { client := proto.NewDaemonServiceClient(conn) request := &proto.DebugBundleRequest{ - Anonymize: anonymizeFlag, - Status: getStatusOutput(cmd, anonymizeFlag), - SystemInfo: debugSystemInfoFlag, + Anonymize: anonymizeFlag, + Status: getStatusOutput(cmd, anonymizeFlag), + SystemInfo: systemInfoFlag, + LogFileCount: logFileCount, } - if debugUploadBundle { - request.UploadURL = debugUploadBundleURL + if uploadBundleFlag { + request.UploadURL = uploadBundleURLFlag } resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { @@ -105,7 +114,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error { return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) } - if debugUploadBundle { + if uploadBundleFlag { cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) } @@ -223,12 +232,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) request := &proto.DebugBundleRequest{ - Anonymize: anonymizeFlag, - Status: statusOutput, - SystemInfo: debugSystemInfoFlag, + Anonymize: anonymizeFlag, + Status: statusOutput, + SystemInfo: systemInfoFlag, + LogFileCount: logFileCount, } - if debugUploadBundle { - request.UploadURL = debugUploadBundleURL + if uploadBundleFlag { + request.UploadURL = uploadBundleURLFlag } resp, err := client.DebugBundle(cmd.Context(), request) if err != nil { @@ -255,7 +265,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) } - if debugUploadBundle { + if uploadBundleFlag { cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) } @@ -375,3 +385,15 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect } log.Infof("Generated debug bundle from SIGUSR1 at: %s", path) } + +func init() { + debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle") + debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") + debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") + debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") + + forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle") + forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") + forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") + forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") +} diff --git a/client/cmd/root.go b/client/cmd/root.go index e00a9b073..fa4bd4d42 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -22,7 +22,6 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/upload-server/types" ) const ( @@ -38,10 +37,7 @@ const ( serverSSHAllowedFlag = "allow-server-ssh" extraIFaceBlackListFlag = "extra-iface-blacklist" dnsRouteIntervalFlag = "dns-router-interval" - systemInfoFlag = "system-info" enableLazyConnectionFlag = "enable-lazy-connection" - uploadBundle = "upload-bundle" - uploadBundleURL = "upload-bundle-url" ) var ( @@ -75,10 +71,7 @@ var ( autoConnectDisabled bool extraIFaceBlackList []string anonymizeFlag bool - debugSystemInfoFlag bool dnsRouteInterval time.Duration - debugUploadBundle bool - debugUploadBundleURL string lazyConnEnabled bool rootCmd = &cobra.Command{ @@ -186,9 +179,6 @@ func init() { upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.") - debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") - debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL)) - debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") } // SetupCloseHandler handles SIGTERM signal and exits with success diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index dfed47f05..6455b3aaf 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -167,6 +167,7 @@ type BundleGenerator struct { anonymize bool clientStatus string includeSystemInfo bool + logFileCount uint32 archive *zip.Writer } @@ -175,6 +176,7 @@ type BundleConfig struct { Anonymize bool ClientStatus string IncludeSystemInfo bool + LogFileCount uint32 } type GeneratorDependencies struct { @@ -185,6 +187,12 @@ type GeneratorDependencies struct { } func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { + // Default to 1 log file for backward compatibility when 0 is provided + logFileCount := cfg.LogFileCount + if logFileCount == 0 { + logFileCount = 1 + } + return &BundleGenerator{ anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), @@ -196,6 +204,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen anonymize: cfg.Anonymize, clientStatus: cfg.ClientStatus, includeSystemInfo: cfg.IncludeSystemInfo, + logFileCount: logFileCount, } } @@ -561,32 +570,8 @@ func (g *BundleGenerator) addLogfile() error { return fmt.Errorf("add client log file to zip: %w", err) } - // add latest rotated log file - pattern := filepath.Join(logDir, "client-*.log.gz") - files, err := filepath.Glob(pattern) - if err != nil { - log.Warnf("failed to glob rotated logs: %v", err) - } else if len(files) > 0 { - // pick the file with the latest ModTime - sort.Slice(files, func(i, j int) bool { - fi, err := os.Stat(files[i]) - if err != nil { - log.Warnf("failed to stat rotated log %s: %v", files[i], err) - return false - } - fj, err := os.Stat(files[j]) - if err != nil { - log.Warnf("failed to stat rotated log %s: %v", files[j], err) - return false - } - return fi.ModTime().Before(fj.ModTime()) - }) - latest := files[len(files)-1] - name := filepath.Base(latest) - if err := g.addSingleLogFileGz(latest, name); err != nil { - log.Warnf("failed to add rotated log %s: %v", name, err) - } - } + // add rotated log files based on logFileCount + g.addRotatedLogFiles(logDir) stdErrLogPath := filepath.Join(logDir, errorLogFile) stdoutLogPath := filepath.Join(logDir, stdoutLogFile) @@ -670,6 +655,52 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error { return nil } +// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount +func (g *BundleGenerator) addRotatedLogFiles(logDir string) { + if g.logFileCount == 0 { + return + } + + pattern := filepath.Join(logDir, "client-*.log.gz") + files, err := filepath.Glob(pattern) + if err != nil { + log.Warnf("failed to glob rotated logs: %v", err) + return + } + + if len(files) == 0 { + return + } + + // sort files by modification time (newest first) + sort.Slice(files, func(i, j int) bool { + fi, err := os.Stat(files[i]) + if err != nil { + log.Warnf("failed to stat rotated log %s: %v", files[i], err) + return false + } + fj, err := os.Stat(files[j]) + if err != nil { + log.Warnf("failed to stat rotated log %s: %v", files[j], err) + return false + } + return fi.ModTime().After(fj.ModTime()) + }) + + // include up to logFileCount rotated files + maxFiles := int(g.logFileCount) + if maxFiles > len(files) { + maxFiles = len(files) + } + + for i := 0; i < maxFiles; i++ { + name := filepath.Base(files[i]) + if err := g.addSingleLogFileGz(files[i], name); err != nil { + log.Warnf("failed to add rotated log %s: %v", name, err) + } + } +} + func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error { header := &zip.FileHeader{ Name: filename, diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 202dc6f89..26e58d183 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -2290,6 +2290,7 @@ type DebugBundleRequest struct { Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"` UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"` + LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -2352,6 +2353,13 @@ func (x *DebugBundleRequest) GetUploadURL() string { return "" } +func (x *DebugBundleRequest) GetLogFileCount() uint32 { + if x != nil { + return x.LogFileCount + } + return 0 +} + type DebugBundleResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` @@ -3746,14 +3754,15 @@ const file_daemon_proto_rawDesc = "" + "\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" + "\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" + "\x17ForwardingRulesResponse\x12,\n" + - "\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x88\x01\n" + + "\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" + "\x12DebugBundleRequest\x12\x1c\n" + "\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" + "\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" + "\n" + "systemInfo\x18\x03 \x01(\bR\n" + "systemInfo\x12\x1c\n" + - "\tuploadURL\x18\x04 \x01(\tR\tuploadURL\"}\n" + + "\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" + + "\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" + "\x13DebugBundleResponse\x12\x12\n" + "\x04path\x18\x01 \x01(\tR\x04path\x12 \n" + "\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index f488e69e7..462555c82 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -356,6 +356,7 @@ message DebugBundleRequest { string status = 2; bool systemInfo = 3; string uploadURL = 4; + uint32 logFileCount = 5; } message DebugBundleResponse { diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e0612a6d1..6251f7c52 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -1,8 +1,4 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.3 -// source: daemon.proto package proto @@ -15,31 +11,8 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 - -const ( - DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login" - DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin" - DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up" - DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status" - DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down" - DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig" - DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks" - DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks" - DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks" - DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules" - DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle" - DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel" - DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel" - DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates" - DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState" - DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState" - DaemonService_SetNetworkMapPersistence_FullMethodName = "/daemon.DaemonService/SetNetworkMapPersistence" - DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket" - DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents" - DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents" -) +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 // DaemonServiceClient is the client API for DaemonService service. // @@ -80,7 +53,7 @@ type DaemonServiceClient interface { // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) - SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) + SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) } @@ -93,9 +66,8 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient { } func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LoginResponse) - err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...) if err != nil { return nil, err } @@ -103,9 +75,8 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts } func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(WaitSSOLoginResponse) - err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...) if err != nil { return nil, err } @@ -113,9 +84,8 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin } func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(UpResponse) - err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...) if err != nil { return nil, err } @@ -123,9 +93,8 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp } func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(StatusResponse) - err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...) if err != nil { return nil, err } @@ -133,9 +102,8 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt } func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DownResponse) - err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...) if err != nil { return nil, err } @@ -143,9 +111,8 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts .. } func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetConfigResponse) - err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...) if err != nil { return nil, err } @@ -153,9 +120,8 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques } func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListNetworksResponse) - err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) if err != nil { return nil, err } @@ -163,9 +129,8 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks } func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -173,9 +138,8 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw } func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SelectNetworksResponse) - err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -183,9 +147,8 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe } func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ForwardingRulesResponse) - err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...) if err != nil { return nil, err } @@ -193,9 +156,8 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ } func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DebugBundleResponse) - err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...) if err != nil { return nil, err } @@ -203,9 +165,8 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe } func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetLogLevelResponse) - err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...) if err != nil { return nil, err } @@ -213,9 +174,8 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe } func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetLogLevelResponse) - err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...) if err != nil { return nil, err } @@ -223,9 +183,8 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe } func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListStatesResponse) - err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...) if err != nil { return nil, err } @@ -233,9 +192,8 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ } func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CleanStateResponse) - err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...) if err != nil { return nil, err } @@ -243,9 +201,8 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ } func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DeleteStateResponse) - err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...) if err != nil { return nil, err } @@ -253,9 +210,8 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe } func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetNetworkMapPersistenceResponse) - err := c.cc.Invoke(ctx, DaemonService_SetNetworkMapPersistence_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...) if err != nil { return nil, err } @@ -263,22 +219,20 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in * } func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(TracePacketResponse) - err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...) +func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) { + stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...) if err != nil { return nil, err } - x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream} + x := &daemonServiceSubscribeEventsClient{stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } @@ -288,13 +242,26 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe return x, nil } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent] +type DaemonService_SubscribeEventsClient interface { + Recv() (*SystemEvent, error) + grpc.ClientStream +} + +type daemonServiceSubscribeEventsClient struct { + grpc.ClientStream +} + +func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) { + m := new(SystemEvent) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetEventsResponse) - err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...) if err != nil { return nil, err } @@ -303,7 +270,7 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer -// for forward compatibility. +// for forward compatibility type DaemonServiceServer interface { // Login uses setup key to prepare configuration for the daemon. Login(context.Context, *LoginRequest) (*LoginResponse, error) @@ -340,17 +307,14 @@ type DaemonServiceServer interface { // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) - SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error + SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) mustEmbedUnimplementedDaemonServiceServer() } -// UnimplementedDaemonServiceServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedDaemonServiceServer struct{} +// UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations. +type UnimplementedDaemonServiceServer struct { +} func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") @@ -406,14 +370,13 @@ func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") } -func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error { +func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error { return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented") } func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented") } func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} -func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to DaemonServiceServer will @@ -423,13 +386,6 @@ type UnsafeDaemonServiceServer interface { } func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) { - // If the following call pancis, it indicates UnimplementedDaemonServiceServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } s.RegisterService(&DaemonService_ServiceDesc, srv) } @@ -443,7 +399,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Login_FullMethodName, + FullMethod: "/daemon.DaemonService/Login", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest)) @@ -461,7 +417,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_WaitSSOLogin_FullMethodName, + FullMethod: "/daemon.DaemonService/WaitSSOLogin", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest)) @@ -479,7 +435,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Up_FullMethodName, + FullMethod: "/daemon.DaemonService/Up", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest)) @@ -497,7 +453,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Status_FullMethodName, + FullMethod: "/daemon.DaemonService/Status", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest)) @@ -515,7 +471,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func( } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_Down_FullMethodName, + FullMethod: "/daemon.DaemonService/Down", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest)) @@ -533,7 +489,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_GetConfig_FullMethodName, + FullMethod: "/daemon.DaemonService/GetConfig", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest)) @@ -551,7 +507,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_ListNetworks_FullMethodName, + FullMethod: "/daemon.DaemonService/ListNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) @@ -569,7 +525,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_SelectNetworks_FullMethodName, + FullMethod: "/daemon.DaemonService/SelectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -587,7 +543,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_DeselectNetworks_FullMethodName, + FullMethod: "/daemon.DaemonService/DeselectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) @@ -605,7 +561,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_ForwardingRules_FullMethodName, + FullMethod: "/daemon.DaemonService/ForwardingRules", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest)) @@ -623,7 +579,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_DebugBundle_FullMethodName, + FullMethod: "/daemon.DaemonService/DebugBundle", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest)) @@ -641,7 +597,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_GetLogLevel_FullMethodName, + FullMethod: "/daemon.DaemonService/GetLogLevel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest)) @@ -659,7 +615,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_SetLogLevel_FullMethodName, + FullMethod: "/daemon.DaemonService/SetLogLevel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest)) @@ -677,7 +633,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_ListStates_FullMethodName, + FullMethod: "/daemon.DaemonService/ListStates", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) @@ -695,7 +651,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_CleanState_FullMethodName, + FullMethod: "/daemon.DaemonService/CleanState", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) @@ -713,7 +669,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_DeleteState_FullMethodName, + FullMethod: "/daemon.DaemonService/DeleteState", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) @@ -731,7 +687,7 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_SetNetworkMapPersistence_FullMethodName, + FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest)) @@ -749,7 +705,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_TracePacket_FullMethodName, + FullMethod: "/daemon.DaemonService/TracePacket", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) @@ -762,11 +718,21 @@ func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerS if err := stream.RecvMsg(m); err != nil { return err } - return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream}) + return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream}) } -// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. -type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent] +type DaemonService_SubscribeEventsServer interface { + Send(*SystemEvent) error + grpc.ServerStream +} + +type daemonServiceSubscribeEventsServer struct { + grpc.ServerStream +} + +func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error { + return x.ServerStream.SendMsg(m) +} func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetEventsRequest) @@ -778,7 +744,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: DaemonService_GetEvents_FullMethodName, + FullMethod: "/daemon.DaemonService/GetEvents", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest)) diff --git a/client/proto/generate.sh b/client/proto/generate.sh index 52fe23d7f..f9a2c3750 100755 --- a/client/proto/generate.sh +++ b/client/proto/generate.sh @@ -11,7 +11,7 @@ fi old_pwd=$(pwd) script_path=$(dirname $(realpath "$0")) cd "$script_path" -go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 +go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional cd "$old_pwd" \ No newline at end of file diff --git a/client/server/debug.go b/client/server/debug.go index 7de3e8609..412602b00 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -42,6 +42,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( Anonymize: req.GetAnonymize(), ClientStatus: req.GetStatus(), IncludeSystemInfo: req.GetSystemInfo(), + LogFileCount: req.GetLogFileCount(), }, ) diff --git a/go.mod b/go.mod index a12058278..4a9727373 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/grpc v1.64.1 - google.golang.org/protobuf v1.36.5 + google.golang.org/protobuf v1.36.6 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) diff --git a/go.sum b/go.sum index 6ce503dd1..a622f203f 100644 --- a/go.sum +++ b/go.sum @@ -1164,8 +1164,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= -google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/util/log.go b/util/log.go index 59a064366..53d2b0684 100644 --- a/util/log.go +++ b/util/log.go @@ -14,7 +14,7 @@ import ( "github.com/netbirdio/netbird/formatter" ) -const defaultLogSize = 5 +const defaultLogSize = 15 // InitLog parses and sets log-level input func InitLog(logLevel string, logPath string) error { From 2b9f3319803e74b81f16b5216ab99e70acad24ea Mon Sep 17 00:00:00 2001 From: Pedro Maia Costa <550684+pnmcosta@users.noreply.github.com> Date: Fri, 11 Jul 2025 10:29:10 +0100 Subject: [PATCH 11/12] always suffix ephemeral peer name (#4138) --- management/server/peer.go | 43 +++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index a60513b38..21a9579fc 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -236,11 +236,23 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peer.Name != update.Name { var newLabel string - newLabel, err = getPeerIPDNSLabel(ctx, transaction, peer.IP, accountID, update.Name) + + newLabel, err = nbdns.GetParsedDomainLabel(update.Name) if err != nil { - return fmt.Errorf("failed to get free DNS label: %w", err) + newLabel = "" + } else { + _, err := transaction.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, update.Name) + if err == nil { + newLabel = "" + } } + if newLabel == "" { + newLabel, err = getPeerIPDNSLabel(peer.IP, update.Name) + if err != nil { + return fmt.Errorf("failed to get free DNS label: %w", err) + } + } peer.Name = update.Name peer.DNSLabel = newLabel peerLabelChanged = true @@ -472,6 +484,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var groupsToAdd []string var allowExtraDNSLabels bool var accountID string + var isEphemeral bool if addedByUser { user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { @@ -501,7 +514,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name allowExtraDNSLabels = sk.AllowExtraDNSLabels accountID = sk.AccountID - + isEphemeral = sk.Ephemeral if !sk.AllowExtraDNSLabels && len(peer.ExtraDNSLabels) > 0 { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key doesn't allow extra DNS labels") } @@ -573,11 +586,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } var freeLabel string - freeLabel, err = getPeerIPDNSLabel(ctx, am.Store, freeIP, accountID, peer.Meta.Hostname) - if err != nil { - return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + if isEphemeral || attempt > 1 { + freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + } + } else { + freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err) + } } - newPeer.DNSLabel = freeLabel newPeer.IP = freeIP @@ -647,7 +666,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if isUniqueConstraintError(err) { unlock() unlock = nil - log.WithContext(ctx).Debugf("Failed to add peer in attempt %d, retrying: %v", attempt, err) + log.WithContext(ctx).WithFields(log.Fields{"dns_label": freeLabel, "ip": freeIP}).Tracef("Failed to add peer in attempt %d, retrying: %v", attempt, err) continue } @@ -681,7 +700,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } -func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID, peerHostName string) (string, error) { +func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { ip = ip.To4() dnsName, err := nbdns.GetParsedDomainLabel(peerHostName) @@ -689,12 +708,6 @@ func getPeerIPDNSLabel(ctx context.Context, tx store.Store, ip net.IP, accountID return "", fmt.Errorf("failed to parse peer host name %s: %w", peerHostName, err) } - _, err = tx.GetPeerIdByLabel(ctx, store.LockingStrengthNone, accountID, dnsName) - if err != nil { - //nolint:nilerr - return dnsName, nil - } - return fmt.Sprintf("%s-%d-%d", dnsName, ip[2], ip[3]), nil } From a76c8eafb46cda5e44de0f3784160a34e8cdb4a2 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 11 Jul 2025 11:37:14 +0200 Subject: [PATCH 12/12] [management] sync calls to UpdateAccountPeers from BufferUpdateAccountPeers (#4137) --------- Co-authored-by: Maycon Santos Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com> --- management/server/mock_server/account_mock.go | 10 +- management/server/peer.go | 34 ++++- management/server/peer_test.go | 130 ++++++++++++++++++ 3 files changed, 166 insertions(+), 8 deletions(-) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 4004f1b57..b1ec66286 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -120,14 +120,20 @@ type MockAccountManager struct { GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) } func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - // do nothing + if am.UpdateAccountPeersFunc != nil { + am.UpdateAccountPeersFunc(ctx, accountID) + } } func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - // do nothing + if am.BufferUpdateAccountPeersFunc != nil { + am.BufferUpdateAccountPeersFunc(ctx, accountID) + } } func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { diff --git a/management/server/peer.go b/management/server/peer.go index 21a9579fc..c6ade83c0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -9,6 +9,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/rs/xid" @@ -1280,18 +1281,39 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } } -func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - mu, _ := am.accountUpdateLocks.LoadOrStore(accountID, &sync.Mutex{}) - lock := mu.(*sync.Mutex) +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} - if !lock.TryLock() { +func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) return } + if b.next != nil { + b.next.Stop() + } + go func() { - time.Sleep(time.Duration(am.updateAccountPeersBufferInterval.Load())) - lock.Unlock() + defer b.mu.Unlock() am.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { + am.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) }() } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 07ec5037b..d41020514 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -25,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/status" @@ -2251,3 +2253,131 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, totalPeers, maps.Values(account.SetupKeys)[0].UsedTimes) assert.Equal(t, uint64(totalPeers), account.Network.Serial) } + +func TestBufferUpdateAccountPeers(t *testing.T) { + const ( + peersCount = 1000 + updateAccountInterval = 50 * time.Millisecond + ) + + var ( + deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 + uapLastRun, dpLastRun atomic.Int64 + + totalNewRuns, totalOldRuns int + ) + + uap := func(ctx context.Context, accountID string) { + updatePeersDeleted.Store(deletedPeers.Load()) + updatePeersRuns.Add(1) + uapLastRun.Store(time.Now().UnixMilli()) + time.Sleep(100 * time.Millisecond) + } + + t.Run("new approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) + b := mu.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + uap(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + b.next = time.AfterFunc(updateAccountInterval, func() { + uap(ctx, accountID) + }) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalNewRuns = int(updatePeersRuns.Load()) + }) + + t.Run("old approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) + b := mu.(*sync.Mutex) + + if !b.TryLock() { + return + } + + go func() { + time.Sleep(updateAccountInterval) + b.Unlock() + uap(ctx, accountID) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalOldRuns = int(updatePeersRuns.Load()) + }) + assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) + t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) +}