Compare commits

...

4 Commits

Author SHA1 Message Date
Zoltan Papp
d092827fca [client] Fix pointer comparisons in profile config apply
apply() compared several *bool/*int ConfigInput fields against the
Config fields by pointer identity instead of by value, so any
non-nil input always looked "changed" and triggered a spurious log
line plus an unconditional config rewrite even when the value was
unchanged.
2026-07-01 11:02:57 +02:00
Viktor Liu
4ef65294e9 [client] Reinject captured first packet on lazy connection activation (#6572) 2026-06-30 11:22:25 +02:00
Bethuel Mmbaga
5b5f11740a [misc] Require on-premise EULA acceptance in enterprise scripts (#6596) 2026-06-30 11:34:23 +03:00
Riccardo Manfrin
3de889d529 [client] bound system info / posture-check gathering with a timeout to prevent sync-loop freeze (#6512)
* Wraps system info / posture checks into a goroutine with timeout

e.checks = checks is set before doing the SyncMeta,
so if it fails next time isCheckEquals compares true and bypasses
the update. This avoids repeating the 15-second hang on every sync.
The checks will be synced on reconnect or posture checks changes
push from mgmt.

* Propagate context to OS calls that can leverage its cancellation / timeout

* Distinguish timeout from cancellation in logs

* Don't log twice

* Block on timeout failure and reapply the exclude_ips

* Refactor for complexity
2026-06-30 08:18:51 +02:00
27 changed files with 403 additions and 89 deletions

View File

@@ -136,6 +136,11 @@ func (p *ProxyBind) CloseConn() error {
return p.close()
}
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
// The packet argument is ignored and the returned error is always nil.
func (p *ProxyBind) InjectPacket(_ []byte) error {
return nil
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil

View File

@@ -219,6 +219,17 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Unlock()
}
// InjectPacket sends b as-is to the remote peer through the proxy's remote connection.
// It returns an error if the remote connection is not set yet, or the error reported
// by the underlying write.
func (p *ProxyWrapper) InjectPacket(b []byte) error {
	if p.remoteConn == nil {
		return errors.New("proxy not started")
	}

	// Only the write outcome matters; the byte count is discarded.
	_, err := p.remoteConn.Write(b)
	return err
}
// CloseConn closes the remoteConn and automatically removes the conn instance from the map
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {

View File

@@ -18,4 +18,9 @@ type Proxy interface {
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
InjectPacket(b []byte) error
}

View File

@@ -147,6 +147,17 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.sendPkg = p.srcFakerConn.SendPkg
}
// InjectPacket sends b as-is to the remote peer through the proxy's remote connection.
// It returns an error if the remote connection is not set yet, or the error reported
// by the underlying write.
func (p *WGUDPProxy) InjectPacket(b []byte) error {
	if p.remoteConn == nil {
		return errors.New("proxy not started")
	}

	// Only the write outcome matters; the byte count is discarded.
	_, err := p.remoteConn.Write(b)
	return err
}
// CloseConn closes the localConn
func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil {

View File

@@ -82,6 +82,12 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
disableAutoUpdate = "disabled"
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
// check gathering. The gathering runs uncancellable system calls (process scan,
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
systemInfoTimeout = 15 * time.Second
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -1084,11 +1090,22 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...)
if !ok {
// Gathering timed out; skip the meta sync this cycle rather than blocking the
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
return nil
}
e.applyInfoFlags(info)
if err := e.mgmClient.SyncMeta(info); err != nil {
return fmt.Errorf("could not sync meta: error %s", err)
}
return nil
}
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
func (e *Engine) applyInfoFlags(info *system.Info) {
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
@@ -1107,12 +1124,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
@@ -1272,31 +1283,15 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...)
if !ok {
// Gathering timed out; connect the stream with base info so management
// connectivity still comes up rather than blocking here.
info = system.GetInfo(e.ctx)
}
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
e.applyInfoFlags(info)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client

View File

@@ -178,6 +178,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
// MTU reports a fixed MTU of 1280 for the mock interface.
func (m *MockWGIface) MTU() uint16 {
	const mockMTU = 1280
	return mockMTU
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}

View File

@@ -44,4 +44,5 @@ type wgIfaceBase interface {
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
MTU() uint16
}

View File

@@ -124,6 +124,11 @@ func (d *BindListener) ReadPackets() {
d.done.Done()
}
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
// It always returns nil.
func (d *BindListener) CapturedPacket() []byte {
return nil
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")

View File

@@ -45,10 +45,6 @@ type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager
}
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
@@ -68,6 +64,10 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
// MTU reports a fixed MTU of 1280 for the bind-mode mock interface.
func (m *MockWGIfaceBind) MTU() uint16 {
	const mockMTU = 1280
	return mockMTU
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -207,8 +207,9 @@ func TestManager_BindMode(t *testing.T) {
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case ev := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
@@ -266,8 +267,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case ev := <-mgr.OnActivityChan:
receivedPeers[ev.PeerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}

View File

@@ -3,11 +3,13 @@ package activity
import (
"fmt"
"net"
"slices"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
@@ -20,6 +22,8 @@ type UDPListener struct {
done sync.Mutex
isClosed atomic.Bool
capturedPacket []byte
}
// NewUDPListener creates a listener that detects activity via UDP socket reads.
@@ -46,9 +50,13 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
}
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
// The first packet that triggers activity is captured so it can be reinjected through the real
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
// dropped and WG would only retry after REKEY_TIMEOUT.
func (d *UDPListener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Infof("exit from activity listener")
@@ -62,20 +70,24 @@ func (d *UDPListener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
d.peerCfg.Log.Infof("activity detected")
d.capturedPacket = slices.Clone(buf[:n])
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
break
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
// Ignore close error as it may return "use of closed network connection" if already closed.
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
// triggered activation.
_ = d.conn.Close()
d.done.Unlock()
}
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
// Safe to call after ReadPackets returns: the field is assigned by ReadPackets without any
// locking, so it must not be read while ReadPackets is still running.
// The returned slice is the listener's own stored slice, not a copy.
func (d *UDPListener) CapturedPacket() []byte {
return d.capturedPacket
}
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())

View File

@@ -19,17 +19,25 @@ import (
type listener interface {
ReadPackets()
Close()
CapturedPacket() []byte
}
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
// captured for reinjection through the real transport.
type Event struct {
// PeerConnID identifies the peer connection on which activity was detected.
PeerConnID peerid.ConnID
// FirstPacket is whatever the listener's CapturedPacket returned; nil when nothing was captured.
FirstPacket []byte
}
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
MTU() uint16
}
type Manager struct {
OnActivityChan chan peerid.ConnID
OnActivityChan chan Event
wgIface WgInterface
@@ -41,7 +49,7 @@ type Manager struct {
func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
OnActivityChan: make(chan Event, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
@@ -116,12 +124,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
func (m *Manager) notify(ev Event) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
case m.OnActivityChan <- ev:
}
}

View File

@@ -1,6 +1,7 @@
package activity
import (
"bytes"
"net"
"net/netip"
"testing"
@@ -25,10 +26,6 @@ func (m *MocPeer) ConnID() peerid.ConnID {
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
@@ -44,6 +41,10 @@ func (m MocWGIface) Address() wgaddr.Address {
}
}
// MTU reports a fixed MTU of 1280 for the mock interface.
func (m MocWGIface) MTU() uint16 {
	const mockMTU = 1280
	return mockMTU
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock()
@@ -86,11 +87,15 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
case ev := <-mgr.OnActivityChan:
if ev.PeerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
}
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
}
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}

View File

@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(peerConnID)
case ev := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ev)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
}
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
func (m *Manager) onPeerActivity(ev activity.Event) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
return
}
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
}
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {

View File

@@ -17,4 +17,5 @@ type WGIface interface {
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
MTU() uint16
}

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"runtime"
"slices"
"sync"
"time"
@@ -136,6 +137,39 @@ type Conn struct {
// Connection stage timestamps for metrics
metricsRecorder MetricsRecorder
metricsStages *MetricsStages
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
// transport is up.
pendingFirstPacket []byte
}
// injectPendingFirstPacket replays the stashed first packet over the transport that just came up.
// The proxy is preferred when one is given; otherwise the packet is written to directConn.
// The stash is cleared only after a successful write. A failed write, or a call with neither
// transport available, leaves the packet in place for a later attempt.
// Caller must hold conn.mu.
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
	if len(conn.pendingFirstPacket) == 0 {
		// Nothing was captured, or it has already been replayed.
		return
	}
	firstPacket := conn.pendingFirstPacket

	if proxy == nil && directConn == nil {
		conn.Log.Debugf("no transport available to reinject captured first packet")
		return
	}

	if proxy != nil {
		if err := proxy.InjectPacket(firstPacket); err != nil {
			conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
			return
		}
	} else if _, err := directConn.Write(firstPacket); err != nil {
		conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
		return
	}

	// Delivered: drop the stash so the packet is not replayed on a later transport switch.
	conn.pendingFirstPacket = nil
	conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(firstPacket))
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -172,6 +206,16 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
// nil first packet: a plain open has nothing captured to replay.
return conn.open(engineCtx, nil)
}
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
// the real transport is established. The packet is retained only on a successful open.
// It delegates to the same internal open routine as Open and returns its error unchanged.
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
return conn.open(engineCtx, firstPacket)
}
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -227,6 +271,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
if len(firstPacket) > 0 {
conn.pendingFirstPacket = slices.Clone(firstPacket)
}
conn.opened = true
return nil
}
@@ -423,6 +470,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo, updateTime)
@@ -546,6 +595,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.injectPendingFirstPacket(wgProxy, nil)
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()

View File

@@ -88,11 +88,24 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
if !ok {
return
}
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
// PeerConnOpenWithFirstPacket looks up the peer connection registered under pubKey and opens it,
// handing over firstPacket so it can be reinjected once the real transport is established.
// An unknown pubKey is silently ignored. An open failure is logged on the peer's logger and
// not returned. The store's read lock is held for the whole call, including the open itself.
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
	s.peerConnsMu.RLock()
	defer s.peerConnsMu.RUnlock()

	peerConn, found := s.peerConns[pubKey]
	if !found {
		return
	}

	err := peerConn.OpenWithFirstPacket(ctx, firstPacket)
	if err != nil {
		peerConn.Log.Errorf("failed to open peer connection: %v", err)
	}
}
func (s *Store) PeerConnIdle(pubKey string) {

View File

@@ -386,7 +386,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
if input.NetworkMonitor != nil && (config.NetworkMonitor == nil || *input.NetworkMonitor != *config.NetworkMonitor) {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = input.NetworkMonitor
updated = true
@@ -454,7 +454,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if input.EnableSSHRoot != nil && (config.EnableSSHRoot == nil || *input.EnableSSHRoot != *config.EnableSSHRoot) {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
} else {
@@ -464,7 +464,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
if input.EnableSSHSFTP != nil && (config.EnableSSHSFTP == nil || *input.EnableSSHSFTP != *config.EnableSSHSFTP) {
if *input.EnableSSHSFTP {
log.Infof("enabling SSH SFTP subsystem")
} else {
@@ -474,7 +474,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
if input.EnableSSHLocalPortForwarding != nil && (config.EnableSSHLocalPortForwarding == nil || *input.EnableSSHLocalPortForwarding != *config.EnableSSHLocalPortForwarding) {
if *input.EnableSSHLocalPortForwarding {
log.Infof("enabling SSH local port forwarding")
} else {
@@ -484,7 +484,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
if input.EnableSSHRemotePortForwarding != nil && (config.EnableSSHRemotePortForwarding == nil || *input.EnableSSHRemotePortForwarding != *config.EnableSSHRemotePortForwarding) {
if *input.EnableSSHRemotePortForwarding {
log.Infof("enabling SSH remote port forwarding")
} else {
@@ -494,7 +494,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
if input.DisableSSHAuth != nil && (config.DisableSSHAuth == nil || *input.DisableSSHAuth != *config.DisableSSHAuth) {
if *input.DisableSSHAuth {
log.Infof("disabling SSH authentication")
} else {
@@ -504,7 +504,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
if input.SSHJWTCacheTTL != nil && (config.SSHJWTCacheTTL == nil || *input.SSHJWTCacheTTL != *config.SSHJWTCacheTTL) {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
updated = true
@@ -587,7 +587,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if input.DisableNotifications != nil && (config.DisableNotifications == nil || *input.DisableNotifications != *config.DisableNotifications) {
if *input.DisableNotifications {
log.Infof("disabling notifications")
} else {

View File

@@ -2,9 +2,11 @@ package system
import (
"context"
"errors"
"net/netip"
"slices"
"strings"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
@@ -174,7 +176,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
files, err := checkFileAndProcess(ctx, processCheckPaths)
if err != nil {
return nil, err
}
@@ -187,3 +189,43 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
log.Debugf("all system information gathered successfully")
return info, nil
}
// GetInfoWithChecksTimeout runs GetInfoWithChecks in a goroutine and waits for it for at most
// timeout (or until ctx is canceled, whichever comes first).
//
// Returns:
//   - (info, true) when gathering finished in time. If GetInfoWithChecks failed for a reason
//     other than the context ending, the failure is logged and info is the base GetInfo result
//     with excludeIPs removed.
//   - (nil, false) when the deadline expired or the parent context was canceled; the caller is
//     expected to carry on in a degraded mode instead of blocking.
//
// The result channel has capacity one, so a worker that finishes after the caller has already
// given up can still deliver its result and exit instead of blocking forever on the send.
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, bool) {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	results := make(chan *Info, 1)
	go func() {
		gathered, err := GetInfoWithChecks(ctx, checks, excludeIPs...)
		if err == nil {
			results <- gathered
			return
		}
		if ctx.Err() != nil {
			// The context ended; the waiting side reports that on its own.
			return
		}

		log.Warnf("failed to get system info with checks: %v", err)
		fallback := GetInfo(ctx)
		fallback.removeAddresses(excludeIPs...)
		results <- fallback
	}()

	select {
	case gathered := <-results:
		return gathered, true
	case <-ctx.Done():
	}

	if cause := ctx.Err(); errors.Is(cause, context.DeadlineExceeded) {
		log.Warnf("gathering system info with checks timed out after %s", timeout)
	} else {
		// Parent context canceled (e.g. shutdown), not a timeout.
		log.Warnf("gathering system info with checks canceled: %v", cause)
	}
	return nil, false
}

View File

@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
if err != nil {
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
swVersion = []byte(release)

View File

@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
}
}
func checkFileAndProcess(_ []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
@@ -35,6 +36,20 @@ func Test_CustomHostname(t *testing.T) {
assert.Equal(t, want, got.Hostname)
}
// TestGetInfoWithChecksTimeout_Success checks that a generous budget yields a non-nil result
// flagged as completed.
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
	const budget = 30 * time.Second

	info, ok := GetInfoWithChecksTimeout(context.Background(), budget, nil)

	assert.True(t, ok, "expected gathering to complete within the timeout")
	assert.NotNil(t, info)
}
// TestGetInfoWithChecksTimeout_Timeout checks that an exhausted budget is reported as
// (nil, false) rather than blocking on the goroutine that is still gathering.
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
	// One nanosecond is gone before the (real) system-info gathering can finish.
	const budget = time.Nanosecond

	info, ok := GetInfoWithChecksTimeout(context.Background(), budget, nil)

	assert.False(t, ok, "expected timeout to be reported")
	assert.Nil(t, info)
}
func Test_NetAddresses(t *testing.T) {
addr, err := networkAddresses()
if err != nil {

View File

@@ -3,24 +3,30 @@
package system
import (
"context"
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processIDs, err := process.Pids()
// getRunningProcesses returns a list of running process paths. The context bounds the work:
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
// can, so a stuck enumeration cannot run unbounded.
func getRunningProcesses(ctx context.Context) ([]string, error) {
processIDs, err := process.PidsWithContext(ctx)
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, pID := range processIDs {
if err := ctx.Err(); err != nil {
return nil, err
}
p := &process.Process{Pid: pID}
path, _ := p.Exe()
path, _ := p.ExeWithContext(ctx)
if path != "" {
processMap[path] = false
}
@@ -35,18 +41,21 @@ func getRunningProcesses() ([]string, error) {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
runningProcesses, err := getRunningProcesses(ctx)
if err != nil {
return nil, err
}
for i, path := range paths {
if err := ctx.Err(); err != nil {
return nil, err
}
file := File{Path: path}
_, err := os.Stat(path)

View File

@@ -1,6 +1,7 @@
package system
import (
"context"
"testing"
"github.com/shirou/gopsutil/v3/process"
@@ -9,7 +10,7 @@ import (
func Benchmark_getRunningProcesses(b *testing.B) {
b.Run("getRunningProcesses new", func(b *testing.B) {
for i := 0; i < b.N; i++ {
ps, err := getRunningProcesses()
ps, err := getRunningProcesses(context.Background())
if err != nil {
b.Fatalf("unexpected error: %v", err)
}
@@ -29,12 +30,38 @@ func Benchmark_getRunningProcesses(b *testing.B) {
}
}
})
s, _ := getRunningProcesses()
s, _ := getRunningProcesses(context.Background())
b.Logf("getRunningProcesses returned %d processes", len(s))
s, _ = getRunningProcessesOld()
b.Logf("getRunningProcessesOld returned %d processes", len(s))
}
// TestCheckFileAndProcess_ContextCanceled checks that, given an already-canceled context and a
// non-empty path list, checkFileAndProcess reports an error instead of doing the process scan
// and stat loop.
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	_, err := checkFileAndProcess(canceledCtx, []string{"/does/not/exist"})
	if err == nil {
		t.Fatal("expected error on canceled context, got nil")
	}
}
// TestCheckFileAndProcess_EmptyPaths checks that an empty path list is a no-op: no error and no
// files, even when the context is already canceled, because there is nothing to scan or stat.
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel()

	got, err := checkFileAndProcess(canceledCtx, nil)
	switch {
	case err != nil:
		t.Fatalf("unexpected error for empty paths: %v", err)
	case len(got) != 0:
		t.Fatalf("expected no files, got %d", len(got))
	}
}
func getRunningProcessesOld() ([]string, error) {
processes, err := process.Processes()
if err != nil {

View File

@@ -9,6 +9,8 @@ set -o pipefail
SED_STRIP_PADDING='s/=//g'
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
@@ -139,6 +141,43 @@ read_yes_no() {
esac
}
# Gate the install on explicit acceptance of the NetBird On-Premise EULA.
#
# Prints the EULA notice to stderr, then either honours NB_ACCEPT_EULA=yes
# (non-interactive acceptance) or prompts on the controlling terminal.
# Returns 0 on acceptance; exits the script with status 1 when the EULA is
# declined or no answer can be read.
require_eula_acceptance() {
  # Write to the already-open stderr (>&2) instead of reopening /dev/stderr:
  # reopening truncates a log file that stderr was redirected to, and can fail
  # with "Permission denied" when the terminal belongs to another user.
  cat >&2 <<EOF
──────────────────────────────────────────────────────────────────────
NetBird On-Premise End User License Agreement
──────────────────────────────────────────────────────────────────────
NetBird's on-premise software is commercial software, licensed and not
sold. Your installation, deployment and use are governed by the NetBird
On-Premise End User License Agreement (the "EULA"). Please read the EULA
in full before continuing:
${NETBIRD_EULA_URL}
By typing "accept" and continuing the installation, you confirm that you
have read and agree to the EULA, that you are authorized to accept it on
behalf of your organization (the "Customer"), and that the Software is
used for business purposes only.
──────────────────────────────────────────────────────────────────────
EOF
  if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
    echo "EULA accepted via NB_ACCEPT_EULA=yes." >&2
    return 0
  fi

  local ans=""
  echo -n 'Type "accept" to agree, or anything else to abort: ' >&2
  # Without a controlling terminal (CI, cron, detached shells) the read fails;
  # say how to accept non-interactively instead of dying without explanation.
  if ! read -r ans < /dev/tty; then
    echo "" >&2
    echo "No answer could be read from the terminal. Set NB_ACCEPT_EULA=yes to accept the EULA non-interactively." >&2
    exit 1
  fi
  if [[ "$ans" != "accept" ]]; then
    echo "" >&2
    echo "EULA not accepted. Aborting installation." >&2
    exit 1
  fi
  echo "" >&2
}
wait_postgres() {
set +e
echo -n "Waiting for postgres to become ready"
@@ -174,6 +213,9 @@ init_environment() {
exit 1
fi
require_eula_acceptance
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
echo "NetBird Enterprise bootstrap"
echo ""
echo "Traffic flow:"
@@ -260,6 +302,11 @@ render_env() {
# Generated by getting-started-enterprise.sh
# Holds all configuration and secrets for the stack. Mode 600.
# NetBird On-Premise EULA acceptance
NETBIRD_EULA_ACCEPTED=yes
NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}
NETBIRD_EULA_URL=${NETBIRD_EULA_URL}
# Features (set by the script; don't edit without re-running)
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}

View File

@@ -25,6 +25,8 @@ set -o pipefail
OVERRIDE_FILE="docker-compose.override.yml"
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
@@ -115,6 +117,43 @@ read_yes_no() {
esac
}
# Gate the migration on explicit acceptance of the NetBird On-Premise EULA.
#
# Prints the EULA notice to stderr, then either honours NB_ACCEPT_EULA=yes
# (non-interactive acceptance) or prompts on the controlling terminal.
# Returns 0 on acceptance; exits the script with status 1 when the EULA is
# declined or no answer can be read.
require_eula_acceptance() {
  # Write to the already-open stderr (>&2) instead of reopening /dev/stderr:
  # reopening truncates a log file that stderr was redirected to, and can fail
  # with "Permission denied" when the terminal belongs to another user.
  cat >&2 <<EOF
──────────────────────────────────────────────────────────────────────
NetBird On-Premise End User License Agreement
──────────────────────────────────────────────────────────────────────
NetBird's on-premise software is commercial software, licensed and not
sold. Your installation, deployment and use are governed by the NetBird
On-Premise End User License Agreement (the "EULA"). Please read the EULA
in full before continuing:
${NETBIRD_EULA_URL}
By typing "accept" and continuing the installation, you confirm that you
have read and agree to the EULA, that you are authorized to accept it on
behalf of your organization (the "Customer"), and that the Software is
used for business purposes only.
──────────────────────────────────────────────────────────────────────
EOF
  if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
    echo "EULA accepted via NB_ACCEPT_EULA=yes." >&2
    return 0
  fi

  local ans=""
  echo -n 'Type "accept" to agree, or anything else to abort: ' >&2
  # Without a controlling terminal (CI, cron, detached shells) the read fails;
  # say how to accept non-interactively instead of dying without explanation.
  if ! read -r ans < /dev/tty; then
    echo "" >&2
    echo "No answer could be read from the terminal. Set NB_ACCEPT_EULA=yes to accept the EULA non-interactively." >&2
    exit 1
  fi
  if [[ "$ans" != "accept" ]]; then
    echo "" >&2
    echo "EULA not accepted. Aborting migration." >&2
    exit 1
  fi
  echo "" >&2
}
# ---------------------------------------------------------------------------
# Detection — read the operator's existing compose to find service names and
# paths we need to override. Bail loudly if shape isn't recognised.
@@ -436,6 +475,9 @@ init_migration() {
echo " Network: $COMPOSE_NETWORK"
echo ""
require_eula_acceptance
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
local proceed
proceed=$(read_yes_no "Proceed with migration?" "y")
if [[ "$proceed" != "yes" ]]; then
@@ -529,6 +571,10 @@ apply_changes() {
{
echo ""
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
echo "# NetBird On-Premise EULA accepted at install time"
echo "NETBIRD_EULA_ACCEPTED=yes"
echo "NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}"
echo "NETBIRD_EULA_URL=${NETBIRD_EULA_URL}"
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"