Merge remote-tracking branch 'origin/main' into refactor/permissions-manager

# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/internals/server/modules.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-03-17 12:38:08 +01:00
244 changed files with 17304 additions and 3509 deletions

View File

@@ -27,8 +27,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
@@ -44,13 +44,13 @@ import (
)
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
doInitialAutoUpdate bool
ctx context.Context
config *profilemanager.Config
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
engine *Engine
engineMutex sync.Mutex
updateManager *updater.Manager
persistSyncResponse bool
}
@@ -59,17 +59,19 @@ func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
doInitalAutoUpdate bool,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
doInitialAutoUpdate: doInitalAutoUpdate,
engineMutex: sync.Mutex{},
ctx: ctx,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
c.updateManager = um
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
return c.run(MobileDependency{}, runningChan, logPath)
@@ -187,14 +189,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
stateManager := statemanager.New(path)
stateManager.RegisterState(&sshconfig.ShutdownState{})
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
if err == nil {
updateManager.CheckUpdateSuccess(c.ctx)
if c.updateManager != nil {
c.updateManager.CheckUpdateSuccess(c.ctx)
}
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
inst := installer.New()
if err := inst.CleanUpInstallerFiles(); err != nil {
log.Errorf("failed to clean up temporary installer file: %v", err)
}
defer c.statusRecorder.ClientStop()
@@ -308,7 +309,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
SignalClient: signalClient,
MgmClient: mgmClient,
RelayManager: relayManager,
StatusRecorder: c.statusRecorder,
Checks: checks,
StateManager: stateManager,
UpdateManager: c.updateManager,
}, mobileDependency)
engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engine = engine
c.engineMutex.Unlock()
@@ -318,15 +327,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
// AutoUpdate will be true when the user click on "Connect" menu on the UI
if c.doInitialAutoUpdate {
log.Infof("start engine by ui, run auto-update check")
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
c.doInitialAutoUpdate = false
}
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)

View File

@@ -27,7 +27,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"

View File

@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability() {}
func (d *Resolver) ProbeAvailability(context.Context) {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {

View File

@@ -104,12 +104,16 @@ type DefaultServer struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
probeMu sync.Mutex
probeCancel context.CancelFunc
probeWg sync.WaitGroup
}
type handlerWithStop interface {
dns.Handler
Stop()
ProbeAvailability()
ProbeAvailability(context.Context)
ID() types.HandlerID
}
@@ -362,7 +366,13 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
if s.probeCancel != nil {
s.probeCancel()
}
s.ctxCancel()
s.probeMu.Unlock()
s.probeWg.Wait()
s.shutdownWg.Wait()
s.mux.Lock()
@@ -479,7 +489,8 @@ func (s *DefaultServer) SearchDomains() []string {
}
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
// and deactivates the group if no server responds.
// If a previous probe is still running, it will be cancelled before starting a new one.
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
@@ -492,15 +503,52 @@ func (s *DefaultServer) ProbeAvailability() {
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)
go func(mux handlerWithStop) {
defer wg.Done()
mux.ProbeAvailability()
}(mux.handler)
s.probeMu.Lock()
// don't start probes on a stopped server
if s.ctx.Err() != nil {
s.probeMu.Unlock()
return
}
// cancel any running probe
if s.probeCancel != nil {
s.probeCancel()
s.probeCancel = nil
}
// wait for the previous probe goroutines to finish while holding
// the mutex so no other caller can start a new probe concurrently
s.probeWg.Wait()
// start a new probe
probeCtx, probeCancel := context.WithCancel(s.ctx)
s.probeCancel = probeCancel
s.probeWg.Add(1)
defer s.probeWg.Done()
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
s.mux.Lock()
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
for _, mux := range s.dnsMuxMap {
handlers = append(handlers, mux.handler)
}
s.mux.Unlock()
var wg sync.WaitGroup
for _, handler := range handlers {
wg.Add(1)
go func(h handlerWithStop) {
defer wg.Done()
h.ProbeAvailability(probeCtx)
}(handler)
}
s.probeMu.Unlock()
wg.Wait()
probeCancel()
}
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {

View File

@@ -1065,7 +1065,7 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"runtime"
"strconv"
"sync"
"time"
@@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
@@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := fmt.Sprintf("%s:%d", ip, port)
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
if err != nil {

View File

@@ -65,6 +65,7 @@ type upstreamResolverBase struct {
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
@@ -115,6 +116,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
}
// ServeDNS handles a DNS request
@@ -260,16 +266,10 @@ func formatFailures(failures []upstreamFailure) string {
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability() {
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()
select {
case <-u.ctx.Done():
return
default:
}
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
@@ -279,31 +279,39 @@ func (u *upstreamResolverBase) ProbeAvailability() {
var mu sync.Mutex
var wg sync.WaitGroup
var errors *multierror.Error
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
upstream := upstream
wg.Add(1)
go func() {
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(upstream, 500*time.Millisecond)
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
errors = multierror.Append(errors, err)
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
defer mu.Unlock()
success = true
}()
mu.Unlock()
}(upstream)
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errors.ErrorOrNil())
u.disable(errs.ErrorOrNil())
if u.statusRecorder == nil {
return
@@ -339,7 +347,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(upstream, probeTimeout); err != nil {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
@@ -364,7 +372,9 @@ func (u *upstreamResolverBase) waitUntilResponse() {
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
}
// isTimeout returns true if the given error is a network timeout error.
@@ -387,7 +397,11 @@ func (u *upstreamResolverBase) disable(err error) {
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
go u.waitUntilResponse()
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
}
func (u *upstreamResolverBase) upstreamServersString() string {
@@ -398,13 +412,18 @@ func (u *upstreamResolverBase) upstreamServersString() string {
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}

View File

@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
reactivated = true
}
resolver.ProbeAvailability()
resolver.ProbeAvailability(context.TODO())
if !failed {
t.Errorf("expected that resolving was deactivated")

View File

@@ -51,7 +51,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
"github.com/netbirdio/netbird/client/internal/updater"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
@@ -79,7 +79,6 @@ const (
var ErrResetConnection = fmt.Errorf("reset connection")
// EngineConfig is a config for the Engine
type EngineConfig struct {
WgPort int
WgIfaceName string
@@ -141,6 +140,17 @@ type EngineConfig struct {
LogPath string
}
// EngineServices holds the external service dependencies required by the Engine.
type EngineServices struct {
SignalClient signal.Client
MgmClient mgm.Client
RelayManager *relayClient.Manager
StatusRecorder *peer.Status
Checks []*mgmProto.Checks
StateManager *statemanager.Manager
UpdateManager *updater.Manager
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
@@ -209,7 +219,7 @@ type Engine struct {
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.Manager
updateManager *updater.Manager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
@@ -239,22 +249,17 @@ type localIpUpdater interface {
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
services EngineServices,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
stateManager *statemanager.Manager,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
relayManager: relayManager,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
@@ -262,11 +267,12 @@ func NewEngine(
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: statusRecorder,
stateManager: stateManager,
checks: checks,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
updateManager: services.UpdateManager,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -309,7 +315,7 @@ func (e *Engine) Stop() error {
}
if e.updateManager != nil {
e.updateManager.Stop()
e.updateManager.SetDownloadOnly()
}
log.Info("cleaning up status recorder states")
@@ -559,13 +565,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -793,39 +792,22 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
if e.updateManager == nil {
return
}
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
if autoUpdateSettings.Version == disableAutoUpdate {
log.Infof("auto-update is disabled")
e.updateManager.SetDownloadOnly()
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
if err != nil {
return
}
e.updateManager = updateManager
e.updateManager.Start(e.ctx)
}
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
@@ -842,7 +824,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if update.GetNetbirdConfig() != nil {
@@ -1315,8 +1297,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
e.dnsServer.ProbeAvailability()
go e.dnsServer.ProbeAvailability()
return nil
}

View File

@@ -251,9 +251,6 @@ func TestEngine_SSH(t *testing.T) {
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
@@ -263,10 +260,13 @@ func TestEngine_SSH(t *testing.T) {
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
nil,
)
engine.dnsServer = &dns.MockServer{
@@ -428,13 +428,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -647,13 +652,18 @@ func TestEngine_Sync(t *testing.T) {
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -812,13 +822,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
@@ -1014,13 +1029,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
newNet, err := stdnet.NewNet(context.Background(), nil)
@@ -1546,7 +1566,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}

View File

@@ -12,9 +12,10 @@ const renewTimeout = 10 * time.Second
// Response holds the response from exposing a service.
type Response struct {
ServiceName string
ServiceURL string
Domain string
ServiceName string
ServiceURL string
Domain string
PortAutoAssigned bool
}
type Request struct {
@@ -25,6 +26,7 @@ type Request struct {
Pin string
Password string
UserGroups []string
ListenPort uint16
}
type ManagementClient interface {

View File

@@ -15,6 +15,7 @@ func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
UserGroups: req.UserGroups,
Domain: req.Domain,
NamePrefix: req.NamePrefix,
ListenPort: uint16(req.ListenPort),
}
}
@@ -27,13 +28,15 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
Pin: req.Pin,
Password: req.Password,
UserGroups: req.UserGroups,
ListenPort: req.ListenPort,
}
}
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
return &Response{
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
ServiceName: response.ServiceName,
Domain: response.Domain,
ServiceURL: response.ServiceURL,
PortAutoAssigned: response.PortAutoAssigned,
}
}

View File

@@ -3,7 +3,9 @@ package client
import (
"context"
"fmt"
"net"
"reflect"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -564,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
return dnsinterceptor.New(params)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(params.WgInterface)
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort()))
return dynamic.NewRoute(params, dnsAddr)
default:
return static.NewRoute(params)

View File

@@ -4,8 +4,10 @@ import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()

View File

@@ -1,214 +0,0 @@
//go:build windows || darwin
package updatemanager
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = mockUpdate
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
targetVersionChan := make(chan string, 1)
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

View File

@@ -1,39 +0,0 @@
//go:build !windows && !darwin
package updatemanager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager is a no-op stub for unsupported platforms
type Manager struct{}
// NewManager returns a no-op manager for unsupported platforms
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
return nil, fmt.Errorf("update manager is not supported on this platform")
}
// CheckUpdateSuccess is a no-op on unsupported platforms
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
// no-op
}
// Start is a no-op on unsupported platforms
func (m *Manager) Start(ctx context.Context) {
// no-op
}
// SetVersion is a no-op on unsupported platforms
func (m *Manager) SetVersion(expectedVersion string) {
// no-op
}
// Stop is a no-op on unsupported platforms
func (m *Manager) Stop() {
// no-op
}

View File

@@ -1,4 +1,4 @@
// Package updatemanager provides automatic update management for the NetBird client.
// Package updater provides automatic update management for the NetBird client.
// It monitors for new versions, handles update triggers from management server directives,
// and orchestrates the download and installation of client updates.
//
@@ -32,4 +32,4 @@
//
// This enables verification of successful updates and appropriate user notification
// after the client restarts with the new version.
package updatemanager
package updater

View File

@@ -16,8 +16,8 @@ import (
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
"github.com/netbirdio/netbird/client/internal/updater/downloader"
"github.com/netbirdio/netbird/client/internal/updater/reposign"
)
type Installer struct {

View File

@@ -203,7 +203,10 @@ func (rh *ResultHandler) write(result Result) error {
func (rh *ResultHandler) cleanup() error {
err := os.Remove(rh.resultFile)
if err != nil && !os.IsNotExist(err) {
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
log.Debugf("delete installer result file: %s", rh.resultFile)

View File

@@ -1,12 +1,9 @@
//go:build windows || darwin
package updatemanager
package updater
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
@@ -15,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
"github.com/netbirdio/netbird/client/internal/updater/installer"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
@@ -41,6 +38,9 @@ type Manager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
downloadOnly bool // true when no enforcement from management; notifies UI to download latest
forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
@@ -53,37 +53,38 @@ type Manager struct {
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protect update and expectedVersion fields
pendingVersion *v.Version
// updateMutex protects update, expectedVersion, updateToLatestVersion,
// downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields
updateMutex sync.Mutex
triggerUpdateFn func(context.Context, string) error
// installMutex and installing guard against concurrent installation attempts
installMutex sync.Mutex
installing bool
// protect to start the service multiple times
mu sync.Mutex
autoUpdateSupported func() bool
}
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
if runtime.GOOS == "darwin" {
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Home Brew installation")
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
}
}
return newManager(statusRecorder, stateManager)
}
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted.
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager {
manager := &Manager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
downloadOnly: true,
autoUpdateSupported: isAutoUpdateSupported,
}
manager.triggerUpdateFn = manager.triggerUpdate
stateManager.RegisterState(&UpdateState{})
return manager, nil
return manager
}
// CheckUpdateSuccess checks if the update was successful and send a notification.
@@ -124,8 +125,10 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
}
func (m *Manager) Start(ctx context.Context) {
log.Infof("starting update manager")
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel != nil {
log.Errorf("Manager already started")
return
}
@@ -142,13 +145,32 @@ func (m *Manager) Start(ctx context.Context) {
m.cancel = cancel
m.wg.Add(1)
go m.updateLoop(ctx)
go func() {
defer m.wg.Done()
m.updateLoop(ctx)
}()
}
func (m *Manager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if m.cancel == nil {
log.Errorf("manager not started")
func (m *Manager) SetDownloadOnly() {
m.updateMutex.Lock()
m.downloadOnly = true
m.forceUpdate = false
m.expectedVersion = nil
m.updateToLatestVersion = false
m.lastTrigger = time.Time{}
m.updateMutex.Unlock()
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) {
log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate)
if !m.autoUpdateSupported() {
log.Warnf("auto-update not supported on this platform")
return
}
@@ -159,6 +181,7 @@ func (m *Manager) SetVersion(expectedVersion string) {
log.Errorf("empty expected version provided")
m.expectedVersion = nil
m.updateToLatestVersion = false
m.downloadOnly = true
return
}
@@ -178,12 +201,97 @@ func (m *Manager) SetVersion(expectedVersion string) {
m.updateToLatestVersion = false
}
m.lastTrigger = time.Time{}
m.downloadOnly = false
m.forceUpdate = forceUpdate
select {
case m.mgmUpdateChan <- struct{}{}:
default:
}
}
// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI.
func (m *Manager) Install(ctx context.Context) error {
if !m.autoUpdateSupported() {
return fmt.Errorf("auto-update not supported on this platform")
}
m.updateMutex.Lock()
pending := m.pendingVersion
m.updateMutex.Unlock()
if pending == nil {
return fmt.Errorf("no pending version to install")
}
return m.tryInstall(ctx, pending)
}
// tryInstall ensures only one installation runs at a time. Concurrent callers
// receive an error immediately rather than queuing behind a running install.
func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error {
m.installMutex.Lock()
if m.installing {
m.installMutex.Unlock()
return fmt.Errorf("installation already in progress")
}
m.installing = true
m.installMutex.Unlock()
defer func() {
m.installMutex.Lock()
m.installing = false
m.installMutex.Unlock()
}()
return m.install(ctx, targetVersion)
}
// NotifyUI re-publishes the current update state to a newly connected UI client.
// Only needed for download-only mode where the latest version is already cached
// NotifyUI re-publishes the current update state so a newly connected UI gets the info.
func (m *Manager) NotifyUI() {
m.updateMutex.Lock()
if m.update == nil {
m.updateMutex.Unlock()
return
}
downloadOnly := m.downloadOnly
pendingVersion := m.pendingVersion
latestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
if downloadOnly {
if latestVersion == nil {
return
}
currentVersion, err := v.NewVersion(m.currentVersion)
if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) {
return
}
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": latestVersion.String()},
)
return
}
if pendingVersion != nil {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"},
)
}
}
// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it.
func (m *Manager) Stop() {
if m.cancel == nil {
return
@@ -214,8 +322,6 @@ func (m *Manager) onContextCancel() {
}
func (m *Manager) updateLoop(ctx context.Context) {
defer m.wg.Done()
for {
select {
case <-ctx.Done():
@@ -239,55 +345,89 @@ func (m *Manager) handleUpdate(ctx context.Context) {
return
}
expectedVersion := m.expectedVersion
useLatest := m.updateToLatestVersion
downloadOnly := m.downloadOnly
forceUpdate := m.forceUpdate
curLatestVersion := m.update.LatestVersion()
m.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
// Download-only mode or resolve "latest" to actual version
case downloadOnly, m.updateToLatestVersion:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
m.updateMutex.Unlock()
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
// Install to specific version
case m.expectedVersion != nil:
updateVersion = m.expectedVersion
default:
log.Debugf("no expected version information set")
m.updateMutex.Unlock()
return
}
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
if !m.shouldUpdate(updateVersion) {
if !m.shouldUpdate(updateVersion, forceUpdate) {
m.updateMutex.Unlock()
return
}
m.lastTrigger = time.Now()
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
log.Infof("new version available: %s", updateVersion)
if !downloadOnly && !forceUpdate {
m.pendingVersion = updateVersion
}
m.updateMutex.Unlock()
if downloadOnly {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": updateVersion.String()},
)
return
}
if forceUpdate {
if err := m.tryInstall(ctx, updateVersion); err != nil {
log.Errorf("force update failed: %v", err)
}
return
}
m.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"New version available",
"",
map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"},
)
}
func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error {
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"Updating client",
"Installing update now.",
nil,
)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_CRITICAL,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show", "version": updateVersion.String()},
map[string]string{"progress_window": "show", "version": pendingVersion.String()},
)
updateState := UpdateState{
PreUpdateVersion: m.currentVersion,
TargetVersion: updateVersion.String(),
TargetVersion: pendingVersion.String(),
}
if err := m.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
@@ -296,8 +436,9 @@ func (m *Manager) handleUpdate(ctx context.Context) {
}
}
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
inst := installer.New()
if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil {
log.Errorf("error triggering update: %v", err)
m.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
@@ -305,7 +446,9 @@ func (m *Manager) handleUpdate(ctx context.Context) {
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
return err
}
return nil
}
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
@@ -339,7 +482,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
return updateState, nil
}
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
if m.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
@@ -354,8 +497,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
return false
}
if time.Since(m.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute {
log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
return false
}
@@ -367,8 +510,3 @@ func (m *Manager) lastResultErrReason() string {
result := installer.NewResultHandler(inst.TempDir())
return result.GetErrorResultReason()
}
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
inst := installer.New()
return inst.RunInstallation(ctx, targetVersion)
}

View File

@@ -0,0 +1,111 @@
//go:build !windows && !darwin
package updater
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// On Linux, only Mode 1 (downloadOnly) is supported.
// SetVersion is a no-op because auto-update installation is not supported.
func Test_LatestVersion_Linux(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should notify again when a newer version arrives even within 5 minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: true,
},
{
name: "Shouldn't notify initially, but should notify as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
recorder := peer.NewRecorder("")
sub := recorder.SubscribeToEvents()
defer recorder.UnsubscribeFromEvents(sub)
m := NewManager(recorder, statemanager.New(tmpFile))
m.update = mockUpdate
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetDownloadOnly()
ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond)
triggeredInit := ver != ""
if enforced {
t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name)
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial notify mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() {
t.Errorf("%s: Initial version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
ver, enforced = waitForUpdateEvent(sub, 500*time.Millisecond)
triggeredLater := ver != ""
if enforced {
t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name)
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Later notify mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() {
t.Errorf("%s: Later version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
}
m.Stop()
}
}
func Test_SetVersion_NoOp_Linux(t *testing.T) {
// On Linux, SetVersion should be a no-op — no events fired
tmpFile := path.Join(t.TempDir(), "update-test-noop.json")
recorder := peer.NewRecorder("")
sub := recorder.SubscribeToEvents()
defer recorder.UnsubscribeFromEvents(sub)
m := NewManager(recorder, statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))}
m.currentVersion = "1.0.0"
m.Start(context.Background())
m.SetVersion("1.0.1", false)
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
if ver != "" {
t.Errorf("SetVersion should be a no-op on Linux, but got event with version %s", ver)
}
m.Stop()
}

View File

@@ -0,0 +1,227 @@
//go:build windows || darwin
package updater
import (
"context"
"fmt"
"path"
"testing"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
)
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should notify again when a newer version arrives even within 5 minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: true,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
recorder := peer.NewRecorder("")
sub := recorder.SubscribeToEvents()
defer recorder.UnsubscribeFromEvents(sub)
m := NewManager(recorder, statemanager.New(tmpFile))
m.update = mockUpdate
m.currentVersion = c.daemonVersion
m.autoUpdateSupported = func() bool { return true }
m.Start(context.Background())
m.SetVersion("latest", false)
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
triggeredInit := ver != ""
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
ver, _ = waitForUpdateEvent(sub, 500*time.Millisecond)
triggeredLater := ver != ""
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Later update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() {
t.Errorf("%s: Later update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Install to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Install to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Install to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Install to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Install to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
recorder := peer.NewRecorder("")
sub := recorder.SubscribeToEvents()
defer recorder.UnsubscribeFromEvents(sub)
m := NewManager(recorder, statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
m.currentVersion = c.daemonVersion
m.autoUpdateSupported = func() bool { return true }
m.Start(context.Background())
m.SetVersion(c.expectedVersion, false)
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
updateTriggered := ver != ""
if updateTriggered {
if c.expectedVersion == "latest" && c.latestVersion != nil && ver != c.latestVersion.String() {
t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
} else if c.expectedVersion != "latest" && c.expectedVersion != "development" && ver != c.expectedVersion {
t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.expectedVersion, ver)
}
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}
func Test_EnforcedMetadata(t *testing.T) {
// Mode 1 (downloadOnly): no enforced metadata
tmpFile := path.Join(t.TempDir(), "update-test-mode1.json")
recorder := peer.NewRecorder("")
sub := recorder.SubscribeToEvents()
defer recorder.UnsubscribeFromEvents(sub)
m := NewManager(recorder, statemanager.New(tmpFile))
m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))}
m.currentVersion = "1.0.0"
m.Start(context.Background())
m.SetDownloadOnly()
ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond)
if ver == "" {
t.Fatal("Mode 1: expected new_version_available event")
}
if enforced {
t.Error("Mode 1: expected no enforced metadata")
}
m.Stop()
// Mode 2 (enforced, forceUpdate=false): enforced metadata present, no auto-install
tmpFile2 := path.Join(t.TempDir(), "update-test-mode2.json")
recorder2 := peer.NewRecorder("")
sub2 := recorder2.SubscribeToEvents()
defer recorder2.UnsubscribeFromEvents(sub2)
m2 := NewManager(recorder2, statemanager.New(tmpFile2))
m2.update = &versionUpdateMock{latestVersion: nil}
m2.currentVersion = "1.0.0"
m2.autoUpdateSupported = func() bool { return true }
m2.Start(context.Background())
m2.SetVersion("1.0.1", false)
ver, enforced2 := waitForUpdateEvent(sub2, 500*time.Millisecond)
if ver == "" {
t.Fatal("Mode 2: expected new_version_available event")
}
if !enforced2 {
t.Error("Mode 2: expected enforced metadata")
}
m2.Stop()
}
// ensure the proto import is used
var _ = cProto.SystemEvent_INFO

View File

@@ -0,0 +1,56 @@
package updater
import (
"strconv"
"time"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
)
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (m versionUpdateMock) StopWatch() {}
func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
m.onUpdate = updateFn
}
func (m versionUpdateMock) LatestVersion() *v.Version {
return m.latestVersion
}
func (m versionUpdateMock) StartFetcher() {}
// waitForUpdateEvent waits for a new_version_available event, returns the version string or "" on timeout.
func waitForUpdateEvent(sub *peer.EventSubscription, timeout time.Duration) (version string, enforced bool) {
timer := time.NewTimer(timeout)
defer timer.Stop()
for {
select {
case event, ok := <-sub.Events():
if !ok {
return "", false
}
if val, ok := event.Metadata["new_version_available"]; ok {
enforced := false
if raw, ok := event.Metadata["enforced"]; ok {
if parsed, err := strconv.ParseBool(raw); err == nil {
enforced = parsed
}
}
return val, enforced
}
case <-timer.C:
return "", false
}
}
}

View File

@@ -10,7 +10,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
"github.com/netbirdio/netbird/client/internal/updater/downloader"
)
const (

View File

@@ -0,0 +1,22 @@
package updater
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/updater/installer"
)
func isAutoUpdateSupported() bool {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
isBrew := !installer.TypeOfInstaller(ctx).Downloadable()
if isBrew {
log.Warnf("auto-update disabled on Homebrew installation")
return false
}
return true
}

View File

@@ -0,0 +1,7 @@
//go:build !windows && !darwin
package updater
func isAutoUpdateSupported() bool {
return false
}

View File

@@ -0,0 +1,5 @@
package updater
func isAutoUpdateSupported() bool {
return true
}

View File

@@ -1,4 +1,4 @@
package updatemanager
package updater
import v "github.com/hashicorp/go-version"