Merge branch 'main' into drop-dns-probes

This commit is contained in:
Viktor Liu
2026-05-06 11:10:03 +02:00
213 changed files with 12041 additions and 7636 deletions

View File

@@ -331,6 +331,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)

View File

@@ -61,6 +61,7 @@ allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
Anonymization Process
@@ -234,6 +235,7 @@ type BundleGenerator struct {
logPath string
tempDir string
cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -257,7 +259,8 @@ type GeneratorDependencies struct {
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
CapturePath string
RefreshStatus func()
ClientMetrics MetricsExporter
}
@@ -277,6 +280,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -346,6 +350,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addCaptureFile(); err != nil {
log.Errorf("failed to add capture file to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
@@ -669,6 +677,29 @@ func (g *BundleGenerator) addCPUProfile() error {
return nil
}
func (g *BundleGenerator) addCaptureFile() error {
if g.capturePath == "" {
return nil
}
if g.anonymize {
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
return nil
}
f, err := os.Open(g.capturePath)
if err != nil {
return fmt.Errorf("open capture file: %w", err)
}
defer f.Close()
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
return fmt.Errorf("add capture file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)

View File

@@ -28,6 +28,7 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
@@ -68,6 +69,7 @@ import (
signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -218,6 +220,8 @@ type Engine struct {
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
afpacketCapture *capture.AFPacketCapture
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
@@ -935,7 +939,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(update.Urls)
urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
@@ -1686,6 +1695,11 @@ func (e *Engine) parseNATExternalIPMappings() []string {
}
func (e *Engine) close() {
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil {
@@ -2151,6 +2165,62 @@ func (e *Engine) Address() (netip.Addr, error) {
return e.wgInterface.Address().IP, nil
}
// SetCapture sets or clears packet capture on the WireGuard device.
// On userspace WireGuard, it taps the FilteredDevice directly.
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
// Pass nil to disable capture.
func (e *Engine) SetCapture(pc device.PacketCapture) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
intf := e.wgInterface
if intf == nil {
return errors.New("wireguard interface not initialized")
}
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
dev := intf.GetDevice()
if dev != nil {
dev.SetCapture(pc)
e.setForwarderCapture(pc)
return nil
}
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
if pc == nil {
return nil
}
sess, ok := pc.(*capture.Session)
if !ok {
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
}
afc := capture.NewAFPacketCapture(intf.Name(), sess)
if err := afc.Start(); err != nil {
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
}
e.afpacketCapture = afc
return nil
}
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
if e.firewall == nil {
return
}
type forwarderCapturer interface {
SetPacketCapture(pc forwarder.PacketCapture)
}
if fc, ok := e.firewall.(forwarderCapturer); ok {
fc.SetPacketCapture(pc)
}
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
@@ -2372,6 +2442,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
}
}
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -2382,7 +2454,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
RelaySrvIP: relayIP,
SessionID: sessionID,
}
return &offerAnswer, nil
}
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
// netip.Addr. Returns the zero value for empty input and logs a warning
// for malformed payloads.
func decodeRelayIP(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
ip, ok := netip.AddrFromSlice(b)
if !ok {
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
return netip.Addr{}
}
return ip.Unmap()
}

View File

@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}

View File

@@ -3,7 +3,6 @@ package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
@@ -18,10 +17,6 @@ import (
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
@@ -181,10 +176,6 @@ func TestBindListener_Close(t *testing.T) {
}
func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -226,10 +217,6 @@ func TestManager_BindMode(t *testing.T) {
}
func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}

View File

@@ -4,14 +4,12 @@ import (
"errors"
"net"
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -75,16 +73,6 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider)
if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")

View File

@@ -6,7 +6,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
@@ -91,8 +90,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
m.routesMu.Lock()
defer m.routesMu.Unlock()
maps.Clear(m.peerToHAGroups)
maps.Clear(m.haGroupToPeers)
clear(m.peerToHAGroups)
clear(m.haGroupToPeers)
for haUniqueID, routes := range haMap {
var peers []string

View File

@@ -3,8 +3,6 @@ package store
import (
"sync"
"golang.org/x/exp/maps"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -30,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() {
m.mux.Lock()
defer m.mux.Unlock()
maps.Clear(m.events)
clear(m.events)
}
func (m *Memory) GetEvents() []*types.Event {

View File

@@ -7,7 +7,8 @@ import (
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
)
func IsForceRelayed() bool {
@@ -16,3 +17,28 @@ func IsForceRelayed() bool {
}
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}
// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
raw := os.Getenv(EnvKeyNBHomeRelayServers)
if raw == "" {
return nil, false
}
parts := strings.Split(raw, ",")
urls := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
urls = append(urls, p)
}
}
if len(urls) == 0 {
return nil, false
}
return urls, true
}

View File

@@ -3,6 +3,7 @@ package peer
import (
"context"
"errors"
"net/netip"
"sync"
"sync/atomic"
@@ -40,6 +41,10 @@ type OfferAnswer struct {
// relay server address
RelaySrvAddress string
// RelaySrvIP is the IP the remote peer is connected to on its
// relay server. Used as a dial target if DNS for RelaySrvAddress
// fails. Zero value if the peer did not advertise an IP.
RelaySrvIP netip.Addr
// SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID
}
@@ -217,8 +222,9 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
answer.SessionID = &sid
}
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr
answer.RelaySrvIP = ip
}
return answer

View File

@@ -8,6 +8,7 @@ import (
type mocListener struct {
lastState int
wg sync.WaitGroup
peersWg sync.WaitGroup
peers int
}
@@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
}
func (l *mocListener) OnPeersListChanged(size int) {
l.peers = size
l.peersWg.Done()
}
func (l *mocListener) setWaiter() {
@@ -43,6 +45,14 @@ func (l *mocListener) wait() {
l.wg.Wait()
}
func (l *mocListener) setPeersWaiter() {
l.peersWg.Add(1)
}
func (l *mocListener) waitPeers() {
l.peersWg.Wait()
}
func Test_notifier_serverState(t *testing.T) {
type scenario struct {
@@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) {
func Test_notifier_SetListener(t *testing.T) {
listener := &mocListener{}
listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier()
n.lastNotification = stateConnecting
n.setListener(listener)
listener.wait()
listener.waitPeers()
if listener.lastState != n.lastNotification {
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
}
@@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) {
func Test_notifier_RemoveListener(t *testing.T) {
listener := &mocListener{}
listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier()
n.lastNotification = stateConnecting
n.setListener(listener)
// setListener replays cached state on a goroutine; wait for both the state
// and peers callbacks to finish so we don't race on listener.peers.
listener.wait()
listener.waitPeers()
n.removeListener()
n.peerListChanged(1)

View File

@@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
log.Warnf("failed to get session ID bytes: %v", err)
}
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,
offerAnswer.WgListenPort,
remoteKey,
&signal.Credential{
msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
Type: bodyType,
WgListenPort: offerAnswer.WgListenPort,
Credential: &signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd,
},
bodyType,
offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress,
sessionIDBytes)
RosenpassPubKey: offerAnswer.RosenpassPubKey,
RosenpassAddr: offerAnswer.RosenpassAddr,
RelaySrvAddress: offerAnswer.RelaySrvAddress,
RelaySrvIP: offerAnswer.RelaySrvIP,
SessionID: sessionIDBytes,
})
if err != nil {
return err
}

View File

@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
// UpdatePeerState updates peer status
func (d *Status) UpdatePeerState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -343,23 +343,29 @@ func (d *Status) UpdatePeerState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
// when we close the connection we will not notify the router manager
if receivedState.ConnStatus == StatusIdle {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
notifyRouter := receivedState.ConnStatus == StatusIdle
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
return nil
}
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -371,17 +377,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
}
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
d.notifier.peerListChanged(numPeers)
return nil
}
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -393,8 +402,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.routeIDLookup.RemoveRemoteRouteID(pref)
}
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
d.notifier.peerListChanged(numPeers)
return nil
}
@@ -410,10 +422,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -431,22 +443,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
return nil
}
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -461,22 +479,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
return nil
}
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -490,22 +514,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
return nil
}
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey]
if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist")
}
@@ -522,12 +552,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
}
return nil
}
@@ -594,17 +630,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification {
d.mux.Unlock()
return
}
d.peerListChangedForNotification = false
d.notifyPeerListChanged()
numPeers := d.numOfPeers()
// snapshot per-peer router state to deliver after the lock is released
type routerDispatch struct {
peerID string
snapshot map[string]RouterState
}
dispatches := make([]routerDispatch, 0, len(d.peers))
for key := range d.peers {
d.notifyPeerStateChangeListeners(key)
snapshot := d.snapshotRouterPeersLocked(key, true)
if snapshot != nil {
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
}
}
d.mux.Unlock()
d.notifier.peerListChanged(numPeers)
for _, rd := range dispatches {
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
}
}
@@ -655,10 +707,12 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
// UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock()
defer d.mux.Unlock()
d.localPeer = localPeerState
d.notifyAddressChanged()
fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
}
// AddLocalPeerStateRoute adds a route to the local peer state
@@ -721,30 +775,36 @@ func (d *Status) CleanLocalPeerStateRoutes() {
// CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() {
d.mux.Lock()
defer d.mux.Unlock()
d.localPeer = LocalPeerState{}
d.notifyAddressChanged()
fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
}
// MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(err error) {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = false
d.managementError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
// MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected() {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = true
d.managementError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
// UpdateSignalAddress update the address of the signal server
@@ -778,21 +838,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
// MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = false
d.signalError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
// MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected() {
d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = true
d.signalError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
}
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
@@ -919,7 +983,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
for _, r := range d.relayMgr.ServerURLs() {
@@ -1012,18 +1076,17 @@ func (d *Status) RemoveConnectionListener() {
d.notifier.removeListener()
}
func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
subs, ok := d.changeNotify[peerID]
if !ok {
return
// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers
// outside the lock so the channel send cannot stall any d.mux holder.
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
if !notify {
return nil
}
if _, ok := d.changeNotify[peerID]; !ok {
return nil
}
// collect the relevant data for router peers
routerPeers := make(map[string]RouterState, len(d.changeNotify))
for pid := range d.changeNotify {
s, ok := d.peers[pid]
@@ -1031,13 +1094,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
log.Warnf("router peer not found in peers list: %s", pid)
continue
}
routerPeers[pid] = RouterState{
Status: s.ConnStatus,
Relayed: s.Relayed,
Latency: s.Latency,
}
}
return routerPeers
}
// dispatchRouterPeers delivers a previously snapshotted router-state map to
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
// fresh, short read of d.changeNotify under the lock to grab subscriber
// channels, then sends outside the lock so a slow consumer cannot block other
// d.mux holders. The send itself stays blocking (only short-circuited by the
// subscriber's context) so peer state transitions are not silently dropped.
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
if routerPeers == nil {
return
}
d.mux.Lock()
subsMap, ok := d.changeNotify[peerID]
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
if ok {
for _, sub := range subsMap {
subs = append(subs, sub)
}
}
d.mux.Unlock()
for _, sub := range subs {
select {
@@ -1047,14 +1132,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
}
}
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}
func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
}
func (d *Status) numOfPeers() int {
return len(d.peers) + len(d.offlinePeers)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"net/netip"
"sync"
"sync/atomic"
@@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relaySupportedOnRemotePeer.Store(true)
// the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
if err != nil {
w.log.Errorf("failed to handle new offer: %s", err)
return
}
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
var serverIP netip.Addr
if srv == remoteOfferAnswer.RelaySrvAddress {
serverIP = remoteOfferAnswer.RelaySrvIP
}
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection")
@@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
})
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
return w.relayManager.RelayInstanceAddress()
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net"
"net/netip"
"runtime"
"sync"
"time"
@@ -177,7 +178,12 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv4zero)
dst := net.IPv4zero
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
if err != nil {
return nil, nil, err
}
@@ -196,7 +202,12 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv6zero)
dst := net.IPv6zero
if runtime.GOOS == "linux" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}
_, gateway, localIP, err = router.Route(dst)
if err != nil {
return nil, nil, err
}

View File

@@ -89,8 +89,16 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
reused := false
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
if !errors.Is(err, unix.EEXIST) {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
// macOS installs its own RTF_IFSCOPE defaults for primary service
// selection on multi-NIC setups, so a route on this ifindex can
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
// still produces the scoped lookup we need.
reused = true
}
af := unix.AF_INET
@@ -102,7 +110,11 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
verb := "installed"
if reused {
verb = "reused existing"
}
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}

View File

@@ -342,6 +342,22 @@ func GetNextHop(ip netip.Addr) (Nexthop, error) {
if err != nil {
return Nexthop{}, fmt.Errorf("new netroute: %w", err)
}
// go-netroute v0.4.0 rejects unspecified destinations on Linux with a hard
// client-side check. Substitute the lowest non-loopback address so the
// lookup falls through to the default route (::1 / 127.0.0.1 would match
// loopback, ::/0.0.0.0 are unspec). BSD/Windows pass the query straight to
// the kernel and need no substitution.
if runtime.GOOS == "linux" && ip.IsUnspecified() {
if ip.Is6() {
// ::2
ip = netip.AddrFrom16([16]byte{15: 2})
} else {
// 0.0.0.1
ip = netip.AddrFrom4([4]byte{0, 0, 0, 1})
}
}
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
if err != nil {
log.Debugf("Failed to get route for %s: %v", ip, err)

View File

@@ -354,9 +354,13 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
require.NoError(t, err, "Should be able to get IPv4 default route")
t.Logf("Initial IPv4 next hop: %s", initialNextHopV4)
if testCase.prefix.Addr().Is6() && !testCase.expectError {
ensureIPv6DefaultRoute(t)
}
initialNextHopV6, err := GetNextHop(netip.IPv6Unspecified())
if testCase.prefix.Addr().Is6() &&
(errors.Is(err, vars.ErrRouteNotFound) || initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun")) {
initialNextHopV6.Intf != nil && strings.HasPrefix(initialNextHopV6.Intf.Name, "utun") {
t.Skip("Skipping test as no ipv6 default route is available")
}
if err != nil && !errors.Is(err, vars.ErrRouteNotFound) {

View File

@@ -0,0 +1,30 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package systemops
import (
"bytes"
"os/exec"
"testing"
)
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
out, err := exec.Command("route", "-6", "add", "default", "-iface", "lo0").CombinedOutput()
if err != nil {
// Existing default; nothing to install or clean up.
if bytes.Contains(out, []byte("route already in table")) {
return
}
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
}
t.Cleanup(func() {
if out, err := exec.Command("route", "-6", "delete", "default").CombinedOutput(); err != nil {
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
}
})
}

View File

@@ -0,0 +1,41 @@
//go:build linux && !android
package systemops
import (
"errors"
"net"
"syscall"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
)
// ensureIPv6DefaultRoute installs a low-preference IPv6 default route via the
// loopback interface so route lookups for global IPv6 prefixes resolve in
// environments without v6 connectivity. Any pre-existing default route wins
// because of its lower metric.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
lo, err := netlink.LinkByName("lo")
require.NoError(t, err, "find loopback interface")
route := &netlink.Route{
Dst: &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)},
LinkIndex: lo.Attrs().Index,
Priority: 1 << 20,
}
if err := netlink.RouteAdd(route); err != nil {
if errors.Is(err, syscall.EEXIST) {
return
}
t.Skipf("install IPv6 fallback default route: %v", err)
}
t.Cleanup(func() {
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) {
t.Logf("delete IPv6 fallback default route: %v", err)
}
})
}

View File

@@ -0,0 +1,34 @@
//go:build windows
package systemops
import (
"bytes"
"os/exec"
"testing"
)
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()
script := `New-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -RouteMetric 9999 -PolicyStore ActiveStore -ErrorAction Stop`
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
if err != nil {
// Existing default; nothing to install or clean up.
if bytes.Contains(out, []byte("already exists")) {
return
}
t.Skipf("install IPv6 fallback default route: %v: %s", err, out)
}
t.Cleanup(func() {
script := `Remove-NetRoute -DestinationPrefix "::/0" -InterfaceAlias "` + loopbackIfaceWindows + `" -Confirm:$false -ErrorAction Stop`
if out, err := exec.Command("powershell", "-Command", script).CombinedOutput(); err != nil {
t.Logf("delete IPv6 fallback default route: %v: %s", err, out)
}
})
}

View File

@@ -7,7 +7,6 @@ import (
"sync"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route"
@@ -44,8 +43,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
maps.Clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes)
clear(rs.deselectedRoutes)
clear(rs.selectedRoutes)
for _, r := range allRoutes {
rs.deselectedRoutes[r] = struct{}{}
}
@@ -78,8 +77,8 @@ func (rs *RouteSelector) SelectAllRoutes() {
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
maps.Clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes)
clear(rs.deselectedRoutes)
clear(rs.selectedRoutes)
}
// DeselectRoutes removes specific routes from the selection.
@@ -116,8 +115,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{}
}
maps.Clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes)
clear(rs.deselectedRoutes)
clear(rs.selectedRoutes)
}
// IsSelected checks if a specific route is selected.

View File

@@ -2,217 +2,358 @@
package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import (
"context"
"fmt"
"runtime"
"sync"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var (
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
// IOKit message types from IOKit/IOMessage.h.
const (
kIOMessageCanSystemSleep uintptr = 0xe0000270
kIOMessageSystemWillSleep uintptr = 0xe0000280
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
)
//export sleepCallbackBridge
func sleepCallbackBridge() {
log.Info("sleepCallbackBridge event triggered")
var (
ioKit iokitFuncs
cf cfFuncs
cfCommonModes uintptr
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
libInitOnce sync.Once
libInitErr error
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeSleep)
}
// callbackThunk is the single C-callable trampoline registered with IOKit.
callbackThunk uintptr
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
session *runLoopSession
// lifecycleMu serializes Register/Deregister so a new registration can't
// start a second runloop while a previous teardown is still pending.
lifecycleMu sync.Mutex
)
// iokitFuncs holds IOKit symbols resolved once at init.
type iokitFuncs struct {
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
IODeregisterForSystemPower func(notifier *uintptr) int32
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
IOServiceClose func(connect uintptr) int32
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
IONotificationPortDestroy func(port uintptr)
}
//export resumedCallbackBridge
func resumedCallbackBridge() {
log.Info("resumedCallbackBridge event triggered")
// cfFuncs holds CoreFoundation symbols resolved once at init.
type cfFuncs struct {
CFRunLoopGetCurrent func() uintptr
CFRunLoopRun func()
CFRunLoopStop func(rl uintptr)
CFRunLoopAddSource func(rl, source, mode uintptr)
CFRunLoopRemoveSource func(rl, source, mode uintptr)
}
//export suspendedCallbackBridge
func suspendedCallbackBridge() {
log.Info("suspendedCallbackBridge event triggered")
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
// session means no runloop is active and the next Register must start one.
type runLoopSession struct {
rl uintptr
port uintptr
notifier uintptr
rp uintptr
}
//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
log.Info("poweredOnCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeWakeUp)
}
// detectorSnapshot pins a detector's callback and done channel so dispatch
// runs with values valid at snapshot time, even if a concurrent
// Deregister/Register rewrites the detector's fields.
type detectorSnapshot struct {
detector *Detector
callback func(event EventType)
done <-chan struct{}
}
// Detector delivers sleep and wake events to a registered callback.
type Detector struct {
callback func(event EventType)
ctx context.Context
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
done chan struct{}
}
// Register installs callback for power events. The first registration starts
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
// registration succeeds or fails; subsequent registrations just add to the
// dispatch set.
func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
serviceRegistryMu.Lock()
if _, exists := serviceRegistry[d]; exists {
serviceRegistryMu.Unlock()
return fmt.Errorf("detector service already registered")
}
d.callback = callback
d.done = make(chan struct{})
serviceRegistry[d] = struct{}{}
needSetup := session == nil
serviceRegistryMu.Unlock()
d.ctx, d.cancel = context.WithCancel(context.Background())
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
if !needSetup {
return nil
}
serviceRegistry[d] = struct{}{}
// CFRunLoop must run on a single fixed OS thread
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.registerNotifications()
}()
errCh := make(chan error, 1)
go runRunLoop(errCh)
if err := <-errCh; err != nil {
serviceRegistryMu.Lock()
delete(serviceRegistry, d)
close(d.done)
d.done = nil
serviceRegistryMu.Unlock()
return err
}
log.Info("sleep detection service started on macOS")
return nil
}
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// and the runloop is stopped and cleaned up.
// Deregister removes the detector. When the last detector leaves, IOKit
// notifications are torn down and the runloop is stopped.
func (d *Detector) Deregister() error {
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
_, exists := serviceRegistry[d]
if !exists {
if _, exists := serviceRegistry[d]; !exists {
serviceRegistryMu.Unlock()
return nil
}
// cancel and remove this detector
d.cancel()
close(d.done)
delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 {
serviceRegistryMu.Unlock()
return nil
}
sess := session
serviceRegistryMu.Unlock()
log.Info("sleep detection service stopping (deregister)")
// Deregister IOKit notifications, stop runloop, and free resources
C.unregisterNotifications()
if sess == nil {
return nil
}
if sess.rl != 0 && sess.port != 0 {
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
}
if sess.notifier != 0 {
n := sess.notifier
ioKit.IODeregisterForSystemPower(&n)
}
// Clear session only after IODeregisterForSystemPower returns so any
// in-flight powerCallback can still look up session.rp to ack sleep.
serviceRegistryMu.Lock()
session = nil
serviceRegistryMu.Unlock()
if sess.rp != 0 {
ioKit.IOServiceClose(sess.rp)
}
if sess.port != 0 {
ioKit.IONotificationPortDestroy(sess.port)
}
if sess.rl != 0 {
cf.CFRunLoopStop(sess.rl)
}
return nil
}
func (d *Detector) triggerCallback(event EventType) {
doneChan := make(chan struct{})
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
if cb == nil || done == nil {
return
}
select {
case <-done:
return
default:
}
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
cb := d.callback
go func(callback func(event EventType)) {
go func() {
defer close(doneChan)
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep callback: %v", r)
}
}()
log.Info("sleep detection event fired")
callback(event)
close(doneChan)
}(cb)
cb(event)
}()
select {
case <-doneChan:
case <-d.ctx.Done():
case <-done:
case <-timeout.C:
log.Warnf("sleep callback timed out")
log.Warn("sleep callback timed out")
}
}
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
func NewDetector() (*Detector, error) {
if err := initLibs(); err != nil {
return nil, err
}
return &Detector{}, nil
}
func initLibs() error {
libInitOnce.Do(func() {
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
return
}
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
return
}
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
if err != nil {
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
return
}
// Launder the uintptr-to-pointer conversion through a Go variable so
// go vet's unsafeptr analyzer doesn't flag a system-library global.
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
// NewCallback slots are a finite, non-reclaimable resource, so register
// a single thunk that dispatches to the current Detector set.
callbackThunk = purego.NewCallback(powerCallback)
})
return libInitErr
}
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
// runloop thread. A Go panic crossing the purego boundary has undefined
// behavior, so contain it here.
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep powerCallback: %v", r)
}
}()
switch messageType {
case kIOMessageCanSystemSleep:
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
allowPowerChange(messageArgument)
case kIOMessageSystemWillSleep:
dispatchEvent(EventTypeSleep)
allowPowerChange(messageArgument)
case kIOMessageSystemHasPoweredOn:
dispatchEvent(EventTypeWakeUp)
}
return 0
}
func allowPowerChange(messageArgument uintptr) {
serviceRegistryMu.Lock()
var port uintptr
if session != nil {
port = session.rp
}
serviceRegistryMu.Unlock()
if port != 0 {
ioKit.IOAllowPowerChange(port, messageArgument)
}
}
func dispatchEvent(event EventType) {
serviceRegistryMu.Lock()
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
for d := range serviceRegistry {
snaps = append(snaps, detectorSnapshot{
detector: d,
callback: d.callback,
done: d.done,
})
}
serviceRegistryMu.Unlock()
for _, s := range snaps {
s.detector.triggerCallback(event, s.callback, s.done)
}
}
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
// result is reported on errCh so Register can surface failures synchronously.
func runRunLoop(errCh chan<- error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
sess, err := setupSession()
if err == nil {
serviceRegistryMu.Lock()
session = sess
serviceRegistryMu.Unlock()
}
errCh <- err
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep runloop: %v", r)
}
}()
cf.CFRunLoopRun()
}
// setupSession performs the IOKit registration on the current thread. Panics
// are converted to errors so runRunLoop never leaves errCh unsent.
func setupSession() (s *runLoopSession, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during runloop setup: %v", r)
}
}()
var portRef, notifier uintptr
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, &notifier)
if rp == 0 {
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
}
rl := cf.CFRunLoopGetCurrent()
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
}

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"net"
"runtime"
"time"
log "github.com/sirupsen/logrus"
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
// Start begins monitoring the WireGuard interface.
// It relies on the provided context cancellation to stop.
//
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
// to avoid the allocation churn of repeatedly dumping the kernel link
// table; on other platforms it falls back to a low-frequency poll.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
defer close(m.done)
@@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
return watchInterface(ctx, ifaceName, expectedIndex)
}
// getInterfaceIndex returns the index of a network interface by name.

View File

@@ -0,0 +1,134 @@
//go:build linux
package internal
import (
"context"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
// deletion or recreation of the WireGuard interface.
//
// The previous implementation polled net.InterfaceByName every 2 s, which
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
// entire kernel link table on every call. On hosts with many veth
// interfaces (containers, bridges) the resulting allocation churn was on
// the order of ~1 GB/day from this single ticker, which on small ARM
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
//
// The event-driven version below allocates only when the kernel actually
// publishes a link event for the tracked interface — typically zero
// allocations between events.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
done := make(chan struct{})
defer close(done)
// Buffer the channel to absorb event bursts (e.g. when many veth
// pairs are created/destroyed at once by container runtimes).
linkChan := make(chan netlink.LinkUpdate, 32)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
// Return shouldRestart=true so the engine recovers monitoring
// via triggerClientRestart instead of silently losing it for
// the rest of the process lifetime.
return true, fmt.Errorf("subscribe to link updates: %w", err)
}
// Race window: the interface could have been deleted (or recreated)
// between the initial getInterfaceIndex() in Start and LinkSubscribe
// completing its handshake with the kernel. Re-check explicitly so we
// do not block forever waiting for an event that already fired.
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
} else if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case update, ok := <-linkChan:
if !ok {
// The vishvananda/netlink subscription goroutine closes
// the channel on receive errors. Signal the engine to
// restart so monitoring is re-established instead of
// silently ending.
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
return true, fmt.Errorf("link subscription channel closed unexpectedly")
}
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
return true, err
}
}
}
}
// inspectLinkEvent classifies a single netlink link update against the
// tracked WireGuard interface. It returns (true, err) when the engine
// should restart monitoring; (false, nil) means the event is unrelated
// and the caller should keep waiting.
//
// The error component, when non-nil, describes the kernel-side reason
// (deletion or rename); the recreation case returns (true, nil) since
// no error condition is reported.
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
eventIndex := int(update.Index)
eventName := ""
if attrs := update.Attrs(); attrs != nil {
eventName = attrs.Name
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
case syscall.RTM_NEWLINK:
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
}
return false, nil
}
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
// tracked interface index.
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
if eventIndex != expectedIndex {
return false, nil
}
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted", ifaceName)
}
// inspectNewLink reports a restart when an RTM_NEWLINK either:
//
// 1. Introduces a link with our name at a different index (recreation
// after a delete), or
//
// 2. Reports a link still at our index but with a different name
// (in-place rename). The previous polling implementation caught
// this implicitly because net.InterfaceByName(ifaceName) would
// start failing; the event-driven version has to test it.
//
// Same name + same index is just a flag/state change on the existing
// interface and is ignored.
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
if eventName == ifaceName && eventIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, eventIndex)
return true, nil
}
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
ifaceName, eventName, expectedIndex)
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
}
return false, nil
}

View File

@@ -0,0 +1,56 @@
//go:build !linux
package internal
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// watchInterface polls net.InterfaceByName at a fixed interval to detect
// deletion or recreation of the WireGuard interface.
//
// This is the fallback used on non-Linux desktop and server platforms
// (darwin, windows, freebsd). It is also compiled on android and ios so
// the package builds on every supported GOOS, but it is never reached
// at runtime there because Start() in wg_iface_monitor.go exits early
// on mobile platforms.
//
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
// RTNLGRP_LINK netlink subscription instead, because on Linux
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
// dumps the entire kernel link table on every call and produces
// significant allocation churn (netbirdio/netbird#3678).
//
// Windows is also reported in #3678 as affected by RSS climb. A future
// follow-up could implement an event-driven watcher there using
// NotifyIpInterfaceChange from iphlpapi.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}