package updatemanager import ( "context" "fmt" "io" "net/http" "os" "path/filepath" "strings" "sync" "time" v "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" ) const ( latestVersion = "latest" // this version will be ignored developmentVersion = "development" ) type UpdateInterface interface { StopWatch() SetDaemonVersion(newVersion string) bool SetOnUpdateListener(updateFn func()) LatestVersion() *v.Version StartFetcher() } type UpdateState struct { PreUpdateVersion string TargetVersion string } func (u UpdateState) Name() string { return "autoUpdate" } type UpdateManager struct { statusRecorder *peer.Status stateManager *statemanager.Manager lastTrigger time.Time mgmUpdateChan chan struct{} updateChannel chan struct{} currentVersion string update UpdateInterface wg sync.WaitGroup cancel context.CancelFunc expectedVersion *v.Version updateToLatestVersion bool // updateMutex protect update and expectedVersion fields updateMutex sync.Mutex // updateFunc is used for testing to mock the triggerUpdate behavior updateFunc func(ctx context.Context, targetVersion string) error } func NewUpdateManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *UpdateManager { manager := &UpdateManager{ statusRecorder: statusRecorder, stateManager: stateManager, mgmUpdateChan: make(chan struct{}, 1), updateChannel: make(chan struct{}, 1), currentVersion: version.NetbirdVersion(), update: version.NewUpdate("nb/client"), } return manager } // CheckUpdateSuccess checks if the update was successful. It works without to start the update manager. func (u *UpdateManager) CheckUpdateSuccess(ctx context.Context) { u.updateStateManager(ctx) } func (u *UpdateManager) Start(ctx context.Context) { if u.cancel != nil { log.Errorf("UpdateManager already started") return } u.update.SetDaemonVersion(u.currentVersion) u.update.SetOnUpdateListener(func() { select { case u.updateChannel <- struct{}{}: default: } }) go u.update.StartFetcher() ctx, cancel := context.WithCancel(ctx) u.cancel = cancel u.wg.Add(1) go u.updateLoop(ctx) } func (u *UpdateManager) SetVersion(expectedVersion string) { log.Infof("set expected agent version for upgrade: %s", expectedVersion) if u.cancel == nil { log.Errorf("UpdateManager not started") return } u.updateMutex.Lock() defer u.updateMutex.Unlock() if expectedVersion == latestVersion { u.updateToLatestVersion = true u.expectedVersion = nil } else { expectedSemVer, err := v.NewVersion(expectedVersion) if err != nil { log.Errorf("Error parsing version: %v", err) return } if u.expectedVersion != nil && u.expectedVersion.Equal(expectedSemVer) { return } u.expectedVersion = expectedSemVer u.updateToLatestVersion = false } select { case u.mgmUpdateChan <- struct{}{}: default: } } func (u *UpdateManager) Stop() { if u.cancel == nil { return } u.cancel() u.updateMutex.Lock() if u.update != nil { u.update.StopWatch() u.update = nil } u.updateMutex.Unlock() u.wg.Wait() } func (u *UpdateManager) onContextCancel() { if u.cancel == nil { return } u.updateMutex.Lock() defer u.updateMutex.Unlock() if u.update != nil { u.update.StopWatch() u.update = nil } } func (u *UpdateManager) updateLoop(ctx context.Context) { defer u.wg.Done() for { select { case <-ctx.Done(): u.onContextCancel() return case <-u.mgmUpdateChan: case <-u.updateChannel: log.Infof("fetched new version info") } u.handleUpdate(ctx) } } func (u *UpdateManager) handleUpdate(ctx context.Context) { var updateVersion *v.Version u.updateMutex.Lock() if u.update == nil { u.updateMutex.Unlock() return } expectedVersion := u.expectedVersion useLatest := u.updateToLatestVersion curLatestVersion := u.update.LatestVersion() u.updateMutex.Unlock() switch { // Resolve "latest" to actual version case useLatest: if curLatestVersion == nil { log.Tracef("latest version not fetched yet") return } updateVersion = curLatestVersion // Update to specific version case expectedVersion != nil: updateVersion = expectedVersion default: log.Debugf("no expected version information set") return } log.Debugf("checking update option, current version: %s, target version: %s", u.currentVersion, updateVersion) if !u.shouldUpdate(updateVersion) { return } ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() u.lastTrigger = time.Now() log.Debugf("Auto-update triggered, current version: %s, target version: %s", u.currentVersion, updateVersion) u.statusRecorder.PublishEvent( cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Automatically updating client", "Your client version is older than auto-update version set in Management, updating client now.", nil, ) u.statusRecorder.PublishEvent( cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "", "", map[string]string{"progress_window": "show"}, ) updateState := UpdateState{ PreUpdateVersion: u.currentVersion, TargetVersion: updateVersion.String(), } if err := u.stateManager.UpdateState(updateState); err != nil { log.Warnf("failed to update state: %v", err) } else { if err = u.stateManager.PersistState(ctx); err != nil { log.Warnf("failed to persist state: %v", err) } } if err := u.triggerUpdate(ctx, updateVersion.String()); err != nil { log.Errorf("Error triggering auto-update: %v", err) u.statusRecorder.PublishEvent( cProto.SystemEvent_ERROR, cProto.SystemEvent_SYSTEM, "Auto-update failed", fmt.Sprintf("Auto-update failed: %v", err), nil, ) u.statusRecorder.PublishEvent( cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "", "", map[string]string{"progress_window": "hide"}, ) } } func (u *UpdateManager) updateStateManager(ctx context.Context) { stateType := &UpdateState{} u.stateManager.RegisterState(stateType) if err := u.stateManager.LoadState(stateType); err != nil { log.Errorf("failed to load state: %v", err) return } state := u.stateManager.GetState(stateType) if state == nil { return } updateState, ok := state.(*UpdateState) if !ok { log.Errorf("failed to cast state to UpdateState") return } log.Debugf("autoUpdate state loaded, %v", *updateState) if updateState.TargetVersion == u.currentVersion { log.Infof("published notification event") u.statusRecorder.PublishEvent( cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Auto-update completed", fmt.Sprintf("Your NetBird Client was auto-updated to version %s", u.currentVersion), nil, ) } if err := u.stateManager.DeleteState(updateState); err != nil { log.Errorf("failed to delete state: %v", err) } else if err = u.stateManager.PersistState(ctx); err != nil { log.Errorf("failed to persist state: %v", err) } } func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool { if u.currentVersion == developmentVersion { log.Debugf("skipping auto-update, running development version") return false } currentVersion, err := v.NewVersion(u.currentVersion) if err != nil { log.Errorf("error checking for update, error parsing version `%s`: %v", u.currentVersion, err) return false } if currentVersion.GreaterThanOrEqual(updateVersion) { log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", u.currentVersion, updateVersion) return false } if time.Since(u.lastTrigger) < 5*time.Minute { log.Debugf("skipping auto-update, last update was %s ago", time.Since(u.lastTrigger)) return false } return true } func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused tempDir, err := os.MkdirTemp("", "netbird-installer-*") if err != nil { return "", fmt.Errorf("error creating temporary directory: %w", err) } // Clean up temp directory on error var success bool defer func() { if !success { if err := os.RemoveAll(tempDir); err != nil { log.Errorf("error cleaning up temporary directory: %v", err) } } }() fileNameParts := strings.Split(fileURL, "/") out, err := os.Create(filepath.Join(tempDir, fileNameParts[len(fileNameParts)-1])) if err != nil { return "", fmt.Errorf("error creating temporary file: %w", err) } defer func() { if err := out.Close(); err != nil { log.Errorf("error closing temporary file: %v", err) } }() req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil) if err != nil { return "", fmt.Errorf("error creating file download request: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { return "", fmt.Errorf("error downloading file: %w", err) } defer func() { if err := resp.Body.Close(); err != nil { log.Errorf("Error closing response body: %v", err) } }() if resp.StatusCode != http.StatusOK { log.Errorf("error downloading update file, received status code: %d", resp.StatusCode) return "", fmt.Errorf("error downloading file, received status code: %d", resp.StatusCode) } _, err = io.Copy(out, resp.Body) if err != nil { return "", fmt.Errorf("error downloading file: %w", err) } log.Infof("downloaded update file to %s", out.Name()) success = true // Mark success to prevent cleanup return out.Name(), nil }