diff --git a/client/internal/updatemanager/manager.go b/client/internal/updatemanager/manager.go index f5a6d7aff..48db70c43 100644 --- a/client/internal/updatemanager/manager.go +++ b/client/internal/updatemanager/manager.go @@ -24,15 +24,25 @@ const ( latestVersion = "latest" ) +type UpdateInterface interface { + StopWatch() + SetDaemonVersion(newVersion string) bool + SetOnUpdateListener(updateFn func()) + LatestVersion() *v.Version + StartFetcher() +} + type UpdateManager struct { lastTrigger time.Time statusRecorder *peer.Status mgmUpdateChan chan struct{} updateChannel chan struct{} wg sync.WaitGroup + currentVersion string + updateFunc func(ctx context.Context, targetVersion string) error cancel context.CancelFunc - update *version.Update + update UpdateInterface expectedVersion *v.Version updateToLatestVersion bool @@ -44,18 +54,26 @@ func NewUpdateManager(statusRecorder *peer.Status) *UpdateManager { statusRecorder: statusRecorder, mgmUpdateChan: make(chan struct{}, 1), updateChannel: make(chan struct{}, 1), + currentVersion: version.NetbirdVersion(), + updateFunc: triggerUpdate, + update: version.NewUpdate("nb/client"), } return manager } +func (u *UpdateManager) WithCustomVersionUpdate(versionUpdate UpdateInterface) *UpdateManager { + u.update = versionUpdate + return u +} + func (u *UpdateManager) Start(ctx context.Context) { if u.cancel != nil { log.Errorf("UpdateManager already started") return } - u.update = version.NewUpdate("nb/client") - u.update.SetDaemonVersion(version.NetbirdVersion()) + go u.update.StartFetcher() + u.update.SetDaemonVersion(u.currentVersion) u.update.SetOnUpdateListener(func() { select { case u.updateChannel <- struct{}{}: @@ -162,7 +180,7 @@ func (u *UpdateManager) handleUpdate(ctx context.Context) { defer cancel() u.lastTrigger = time.Now() - log.Debugf("Auto-update triggered, current version: %s, target version: %s", version.NetbirdVersion(), updateVersion) + log.Debugf("Auto-update triggered, current version: %s, target version: %s", u.currentVersion, updateVersion) u.statusRecorder.PublishEvent( cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, @@ -171,21 +189,20 @@ func (u *UpdateManager) handleUpdate(ctx context.Context) { nil, ) - err := u.triggerUpdate(ctx, updateVersion.String()) + err := u.updateFunc(ctx, updateVersion.String()) if err != nil { log.Errorf("Error triggering auto-update: %v", err) } } func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool { - currentVersionString := version.NetbirdVersion() - currentVersion, err := v.NewVersion(currentVersionString) + currentVersion, err := v.NewVersion(u.currentVersion) if err != nil { - log.Errorf("Error checking for update, error parsing version `%s`: %v", currentVersionString, err) + log.Errorf("Error checking for update, error parsing version `%s`: %v", u.currentVersion, err) return false } if currentVersion.GreaterThanOrEqual(updateVersion) { - log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", currentVersionString, updateVersion) + log.Debugf("Current version (%s) is equal to or higher than auto-update version (%s)", u.currentVersion, updateVersion) return false } diff --git a/client/internal/updatemanager/manager_test.go b/client/internal/updatemanager/manager_test.go new file mode 100644 index 000000000..7cd66e9d6 --- /dev/null +++ b/client/internal/updatemanager/manager_test.go @@ -0,0 +1,203 @@ +package updatemanager + +import ( + "context" + v "github.com/hashicorp/go-version" + "github.com/netbirdio/netbird/client/internal/peer" + "testing" + "time" +) + +type versionUpdateMock struct { + latestVersion *v.Version + onUpdate func() +} + +func (v versionUpdateMock) StopWatch() {} + +func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool { + return false +} + +func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) { + v.onUpdate = updateFn +} + +func (v versionUpdateMock) LatestVersion() *v.Version { + return v.latestVersion +} + +func (v versionUpdateMock) StartFetcher() {} + +func Test_LatestVersion(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + initialLatestVersion *v.Version + latestVersion *v.Version + shouldUpdateInit bool + shouldUpdateLater bool + }{ + { + name: "Should only trigger update once due to time between triggers being < 5 Minutes", + daemonVersion: "1.0.0", + initialLatestVersion: v.Must(v.NewSemver("1.0.1")), + latestVersion: v.Must(v.NewSemver("1.0.2")), + shouldUpdateInit: true, + shouldUpdateLater: false, + }, + { + name: "Shouldn't update initially, but should update as soon as latest version is fetched", + daemonVersion: "1.0.0", + initialLatestVersion: nil, + latestVersion: v.Must(v.NewSemver("1.0.1")), + shouldUpdateInit: false, + shouldUpdateLater: true, + }, + } + + for _, c := range testMatrix { + mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion} + m := NewUpdateManager(peer.NewRecorder("")).WithCustomVersionUpdate(mockUpdate) + + targetVersionChan := make(chan string, 1) + + m.updateFunc = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + m.currentVersion = c.daemonVersion + m.Start(t.Context()) + m.SetVersion("latest") + var triggeredInit bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.initialLatestVersion.String() { + t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion) + } + triggeredInit = true + case <-time.After(10 * time.Millisecond): + triggeredInit = false + } + if triggeredInit != c.shouldUpdateInit { + t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit) + } + + mockUpdate.latestVersion = c.latestVersion + mockUpdate.onUpdate() + + var triggeredLater bool + select { + case targetVersion := <-targetVersionChan: + if targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } + triggeredLater = true + case <-time.After(10 * time.Millisecond): + triggeredLater = false + } + if triggeredLater != c.shouldUpdateLater { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater) + } + + m.Stop() + } +} + +func Test_HandleUpdate(t *testing.T) { + testMatrix := []struct { + name string + daemonVersion string + latestVersion *v.Version + expectedVersion string + shouldUpdate bool + }{ + { + name: "Update to a specific version should update regardless of if latestVersion is available yet", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.56.0", + shouldUpdate: true, + }, + { + name: "Update to specific version should not update if version matches", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.55.0", + shouldUpdate: false, + }, + { + name: "Update to specific version should not update if current version is newer", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "0.54.0", + shouldUpdate: false, + }, + { + name: "Update to latest version should update if latest is newer", + daemonVersion: "0.55.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: true, + }, + { + name: "Update to latest version should not update if latest == current", + daemonVersion: "0.56.0", + latestVersion: v.Must(v.NewSemver("0.56.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if daemon version is invalid", + daemonVersion: "development", + latestVersion: v.Must(v.NewSemver("1.0.0")), + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expecting latest and latest version is unavailable", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "latest", + shouldUpdate: false, + }, + { + name: "Should not update if expected version is invalid", + daemonVersion: "0.55.0", + latestVersion: nil, + expectedVersion: "development", + shouldUpdate: false, + }, + } + for _, c := range testMatrix { + m := NewUpdateManager(peer.NewRecorder("")).WithCustomVersionUpdate(&versionUpdateMock{latestVersion: c.latestVersion}) + targetVersionChan := make(chan string, 1) + + m.updateFunc = func(ctx context.Context, targetVersion string) error { + targetVersionChan <- targetVersion + return nil + } + + m.currentVersion = c.daemonVersion + m.Start(t.Context()) + m.SetVersion(c.expectedVersion) + + var updateTriggered bool + select { + case targetVersion := <-targetVersionChan: + if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion) + } else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion { + t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion) + } + updateTriggered = true + case <-time.After(10 * time.Millisecond): + updateTriggered = false + } + + if updateTriggered != c.shouldUpdate { + t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered) + } + m.Stop() + } +} diff --git a/client/internal/updatemanager/update_darwin.go b/client/internal/updatemanager/update_darwin.go index 80eac0c42..fc91b9cdf 100644 --- a/client/internal/updatemanager/update_darwin.go +++ b/client/internal/updatemanager/update_darwin.go @@ -16,13 +16,13 @@ const ( pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg" ) -func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error { +func triggerUpdate(ctx context.Context, targetVersion string) error { cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client") outBytes, err := cmd.Output() if err != nil && cmd.ProcessState.ExitCode() == 1 { // Not installed using pkg file, thus installed using Homebrew - return u.updateHomeBrew(ctx) + return updateHomeBrew(ctx) } // Installed using pkg file path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(pkgDownloadURL, targetVersion)) @@ -49,7 +49,7 @@ func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) return err } -func (u *UpdateManager) updateHomeBrew(ctx context.Context) error { +func updateHomeBrew(ctx context.Context) error { // Homebrew must be run as a non-root user // To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory fileInfo, err := os.Stat("/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/") diff --git a/client/internal/updatemanager/update_freebsd.go b/client/internal/updatemanager/update_freebsd.go index f987897f1..18b3dea24 100644 --- a/client/internal/updatemanager/update_freebsd.go +++ b/client/internal/updatemanager/update_freebsd.go @@ -4,7 +4,7 @@ package updatemanager import "context" -func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error { +func triggerUpdate(ctx context.Context, targetVersion string) error { // TODO: Implement return nil } diff --git a/client/internal/updatemanager/update_linux.go b/client/internal/updatemanager/update_linux.go index 64caa0f77..7ff116a16 100644 --- a/client/internal/updatemanager/update_linux.go +++ b/client/internal/updatemanager/update_linux.go @@ -4,7 +4,7 @@ package updatemanager import "context" -func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error { +func triggerUpdate(ctx context.Context, targetVersion string) error { // TODO: Implement return nil } diff --git a/client/internal/updatemanager/update_windows.go b/client/internal/updatemanager/update_windows.go index c2b77a805..98516f8bb 100644 --- a/client/internal/updatemanager/update_windows.go +++ b/client/internal/updatemanager/update_windows.go @@ -40,7 +40,7 @@ func installationMethod() string { return "EXE" } -func (u *UpdateManager) updateMSI(ctx context.Context, targetVersion string) error { +func updateMSI(ctx context.Context, targetVersion string) error { path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(msiDownloadURL, targetVersion)) if err != nil { return err @@ -50,7 +50,7 @@ func (u *UpdateManager) updateMSI(ctx context.Context, targetVersion string) err return err } -func (u *UpdateManager) updateEXE(ctx context.Context, targetVersion string) error { +func updateEXE(ctx context.Context, targetVersion string) error { path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(exeDownloadURL, targetVersion)) if err != nil { return err @@ -65,12 +65,12 @@ func (u *UpdateManager) updateEXE(ctx context.Context, targetVersion string) err return err } -func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error { +func triggerUpdate(ctx context.Context, targetVersion string) error { switch installationMethod() { case "EXE": - return u.updateEXE(ctx, targetVersion) + return updateEXE(ctx, targetVersion) case "MSI": - return u.updateMSI(ctx, targetVersion) + return updateMSI(ctx, targetVersion) default: return fmt.Errorf("unsupported installation method: %s", installationMethod()) } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 2403b5d05..52d3ffee8 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -330,7 +330,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, - update: version.NewUpdate("nb/client-ui"), + update: version.NewUpdateAndStart("nb/client-ui"), } s.eventHandler = newEventHandler(s) @@ -529,7 +529,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form { var req proto.SetConfigRequest req.ProfileName = activeProf.Name req.Username = currUser.Username - + if iMngURL != "" { req.ManagementUrl = iMngURL } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index e868c2529..cfe4f32e1 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -180,7 +180,7 @@ func (s *BaseServer) Start(ctx context.Context) error { log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String()) s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled) - s.update = version.NewUpdate("nb/management") + s.update = version.NewUpdateAndStart("nb/management") s.update.SetDaemonVersion(version.NetbirdVersion()) s.update.SetOnUpdateListener(func() { log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion()) diff --git a/version/update.go b/version/update.go index 16f9d5770..a324d97fe 100644 --- a/version/update.go +++ b/version/update.go @@ -42,17 +42,27 @@ func NewUpdate(httpAgent string) *Update { } u := &Update{ - httpAgent: httpAgent, - uiVersion: currentVersion, - fetchTicker: time.NewTicker(fetchPeriod), - fetchDone: make(chan struct{}), + httpAgent: httpAgent, + uiVersion: currentVersion, + fetchDone: make(chan struct{}), } - go u.startFetcher() + + return u +} + +func NewUpdateAndStart(httpAgent string) *Update { + u := NewUpdate(httpAgent) + go u.StartFetcher() + return u } // StopWatch stop the version info fetch loop func (u *Update) StopWatch() { + if u.fetchTicker == nil { + return + } + u.fetchTicker.Stop() select { @@ -97,7 +107,12 @@ func (u *Update) LatestVersion() *goversion.Version { return u.latestAvailable } -func (u *Update) startFetcher() { +func (u *Update) StartFetcher() { + if u.fetchTicker != nil { + return + } + u.fetchTicker = time.NewTicker(fetchPeriod) + if changed := u.fetchVersion(); changed { u.checkUpdate() } diff --git a/version/update_test.go b/version/update_test.go index a733714cf..d5d60800e 100644 --- a/version/update_test.go +++ b/version/update_test.go @@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true @@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) { wg.Add(1) onUpdate := false - u := NewUpdate(httpAgent) + u := NewUpdateAndStart(httpAgent) defer u.StopWatch() u.SetOnUpdateListener(func() { onUpdate = true