From 84501a3f56728c8f08ae1f38ba8233da6f4e3281 Mon Sep 17 00:00:00 2001 From: M Essam Hamed Date: Mon, 25 Aug 2025 20:49:44 +0300 Subject: [PATCH] Fix deadlock issues --- client/internal/engine.go | 10 +- client/internal/updatemanager/manager.go | 207 ++++++++++++++--------- version/update.go | 32 ++-- 3 files changed, 146 insertions(+), 103 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0643817c8..75f1d1012 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -244,7 +244,6 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - updateManager: updatemanager.NewUpdateManager(clientCtx, statusRecorder), } sm := profilemanager.NewServiceManager("") @@ -705,7 +704,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - if update.GetAutoUpdateVersion() != "skip" { + if e.updateManager == nil && update.GetAutoUpdateVersion() != "disabled" { + e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder) + e.updateManager.Start(e.ctx) + } else if e.updateManager != nil && update.GetAutoUpdateVersion() == "disabled" { + e.updateManager.Stop() + e.updateManager = nil + } + if e.updateManager != nil { e.updateManager.SetVersion(update.GetAutoUpdateVersion()) } if update.GetNetbirdConfig() != nil { diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go index 140ddf6dd..b537f0df1 100644 --- a/client/internal/updatemanager/manager.go +++ b/client/internal/updatemanager/manager.go @@ -21,129 +21,174 @@ import ( ) const ( - latestVersion = "latest" - disableAutoUpdate = "disabled" - unknownVersion = "Unknown" + latestVersion = "latest" ) type UpdateManager struct { - ctx context.Context - cancel context.CancelFunc - version string - latestVersion string - update *version.Update lastTrigger time.Time statusRecorder *peer.Status - mutex sync.Mutex - updateChannel chan string - doneChannel chan struct{} + mgmUpdateChan chan struct{} + updateChannel chan struct{} + wg sync.WaitGroup + + cancel context.CancelFunc + update *version.Update + + expectedVersion string + expectedVersionMutex sync.Mutex } -func NewUpdateManager(ctx context.Context, statusRecorder *peer.Status) *UpdateManager { - update := version.NewUpdate("nb/client") - ctx, cancel := context.WithCancel(ctx) +func NewUpdateManager(statusRecorder *peer.Status) *UpdateManager { manager := &UpdateManager{ - update: update, - lastTrigger: time.Now().Add(-10 * time.Minute), statusRecorder: statusRecorder, - ctx: ctx, - cancel: cancel, - version: disableAutoUpdate, - latestVersion: unknownVersion, - updateChannel: make(chan string, 4), - doneChannel: make(chan struct{}), + mgmUpdateChan: make(chan struct{}, 1), + updateChannel: make(chan struct{}, 1), } - update.SetDaemonVersion(version.NetbirdVersion()) - update.SetOnUpdateChannel(manager.updateChannel) - go manager.UpdateLoop() return manager } +func (u *UpdateManager) Start(ctx context.Context) { + if u.cancel != nil { + log.Errorf("UpdateManager already started") + return + } + + u.update = version.NewUpdate("nb/client") + u.update.SetDaemonVersion(version.NetbirdVersion()) + u.update.SetOnUpdateListener(func() { + select { + case u.updateChannel <- struct{}{}: + default: + } + }) + + ctx, cancel := context.WithCancel(ctx) + u.cancel = cancel + + u.wg.Add(1) + go u.updateLoop(ctx) +} + func (u *UpdateManager) SetVersion(v string) { - u.mutex.Lock() - if u.version != v { - log.Tracef("Auto-update version set to %s", v) - u.version = v - u.mutex.Unlock() - u.updateChannel <- unknownVersion - } else { - u.mutex.Unlock() + if u.cancel == nil { + log.Errorf("UpdateManager not started") + return + } + + u.expectedVersionMutex.Lock() + defer u.expectedVersionMutex.Unlock() + if u.expectedVersion == v { + return + } + + u.expectedVersion = v + + select { + case u.mgmUpdateChan <- struct{}{}: + default: } } func (u *UpdateManager) Stop() { + if u.cancel == nil { + return + } + u.cancel() - u.mutex.Lock() - defer u.mutex.Unlock() if u.update != nil { u.update.StopWatch() u.update = nil } - <-u.doneChannel + + u.wg.Wait() } -func (u *UpdateManager) UpdateLoop() { +func (u *UpdateManager) updateLoop(ctx context.Context) { + defer u.wg.Done() + for { select { - case <-u.ctx.Done(): - u.doneChannel <- struct{}{} + case <-ctx.Done(): return - case latestVersion := <-u.updateChannel: - u.mutex.Lock() - if latestVersion != unknownVersion { - u.latestVersion = latestVersion - } - u.mutex.Unlock() - ctx, cancel := context.WithDeadline(u.ctx, time.Now().Add(time.Minute)) - u.CheckForUpdates(ctx) - cancel() + case <-u.mgmUpdateChan: + case <-u.updateChannel: } + + u.handleUpdate(ctx) } } -func (u *UpdateManager) CheckForUpdates(ctx context.Context) { - if u.version == disableAutoUpdate { - log.Trace("Skipped checking for updates, auto-update is disabled") - return - } - currentVersionString := version.NetbirdVersion() - updateVersionString := u.version - if updateVersionString == latestVersion || updateVersionString == "" { - if u.latestVersion == unknownVersion { +func (u *UpdateManager) handleUpdate(ctx context.Context) { + var updateVersion *v.Version + + u.expectedVersionMutex.Lock() + expectedVersion := u.expectedVersion + u.expectedVersionMutex.Unlock() + + // Resolve "latest" to actual version + if expectedVersion == latestVersion { + if !u.isVersionAvailable() { log.Tracef("Latest version not fetched yet") return } - updateVersionString = u.latestVersion + updateVersion = u.update.LatestVersion() + } else { + var err error + updateVersion, err = v.NewSemver(expectedVersion) + if err != nil { + log.Errorf("Failed to parse latest version: %v", err) + return + } } + + if !u.shouldUpdate(updateVersion) { + return + } + + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute)) + defer cancel() + + u.lastTrigger = time.Now() + log.Debugf("Auto-update triggered, current version: %s, target version: %s", version.NetbirdVersion(), updateVersion) + u.statusRecorder.PublishEvent( + cProto.SystemEvent_INFO, + cProto.SystemEvent_SYSTEM, + "Automatically updating client", + "Your client version is older than auto-update version set in Management, updating client now.", + nil, + ) + + err := u.triggerUpdate(ctx, updateVersion.String()) + if err != nil { + log.Errorf("Error triggering auto-update: %v", err) + } +} + +func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool { + currentVersionString := version.NetbirdVersion() currentVersion, err := v.NewVersion(currentVersionString) if err != nil { log.Errorf("Error checking for update, error parsing version `%s`: %v", currentVersionString, err) - return + return false } - updateVersion, err := v.NewVersion(updateVersionString) - if err != nil { - log.Errorf("Error checking for update, error parsing version `%s`: %v", updateVersionString, err) - return + if currentVersion.GreaterThanOrEqual(updateVersion) { + log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", currentVersionString, updateVersion) + return false } - if currentVersion.LessThan(updateVersion) { - if u.lastTrigger.Add(5 * time.Minute).Before(time.Now()) { - u.lastTrigger = time.Now() - log.Debugf("Auto-update triggered, current version: %s, target version: %s", currentVersionString, updateVersionString) - u.statusRecorder.PublishEvent( - cProto.SystemEvent_INFO, - cProto.SystemEvent_SYSTEM, - "Automatically updating client", - "Your client version is older than auto-update version set in Management, updating client now.", - nil, - ) - err = u.triggerUpdate(ctx, updateVersionString) - if err != nil { - log.Errorf("Error triggering auto-update: %v", err) - } - } - } else { - log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", currentVersionString, updateVersionString) + + if time.Since(u.lastTrigger) < 5*time.Minute { + log.Tracef("No need to update") + return false } + + return true +} + +func (u *UpdateManager) isVersionAvailable() bool { + if u.update.LatestVersion() == nil { + return false + } + return true } func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused diff --git a/version/update.go b/version/update.go index 138568418..16f9d5770 100644 --- a/version/update.go +++ b/version/update.go @@ -31,7 +31,6 @@ type Update struct { fetchDone chan struct{} onUpdateListener func() - onUpdateChannel chan string listenerLock sync.Mutex } @@ -42,14 +41,11 @@ func NewUpdate(httpAgent string) *Update { currentVersion, _ = goversion.NewVersion("0.0.0") } - latestAvailable, _ := goversion.NewVersion("0.0.0") - u := &Update{ - httpAgent: httpAgent, - latestAvailable: latestAvailable, - uiVersion: currentVersion, - fetchTicker: time.NewTicker(fetchPeriod), - fetchDone: make(chan struct{}), + httpAgent: httpAgent, + uiVersion: currentVersion, + fetchTicker: time.NewTicker(fetchPeriod), + fetchDone: make(chan struct{}), } go u.startFetcher() return u @@ -95,15 +91,10 @@ func (u *Update) SetOnUpdateListener(updateFn func()) { } } -func (u *Update) SetOnUpdateChannel(updateChannel chan string) { - u.listenerLock.Lock() - defer u.listenerLock.Unlock() - u.onUpdateChannel = updateChannel - if u.isUpdateAvailable() { - u.versionsLock.Lock() - defer u.versionsLock.Unlock() - u.onUpdateChannel <- u.latestAvailable.String() - } +func (u *Update) LatestVersion() *goversion.Version { + u.versionsLock.Lock() + defer u.versionsLock.Unlock() + return u.latestAvailable } func (u *Update) startFetcher() { @@ -181,9 +172,6 @@ func (u *Update) checkUpdate() bool { u.listenerLock.Lock() defer u.listenerLock.Unlock() - if u.onUpdateChannel != nil { - u.onUpdateChannel <- u.latestAvailable.String() - } if u.onUpdateListener == nil { return true } @@ -196,6 +184,10 @@ func (u *Update) isUpdateAvailable() bool { u.versionsLock.Lock() defer u.versionsLock.Unlock() + if u.latestAvailable == nil { + return false + } + if u.latestAvailable.GreaterThan(u.uiVersion) { return true }