mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 23:56:38 +00:00
Merge remote-tracking branch 'origin/main' into refactor/permissions-manager
# Conflicts: # management/internals/modules/reverseproxy/domain/manager/manager.go # management/internals/modules/reverseproxy/service/manager/api.go # management/internals/server/modules.go # management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
@@ -27,8 +27,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -44,13 +44,13 @@ import (
|
||||
)
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
doInitialAutoUpdate bool
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
updateManager *updater.Manager
|
||||
|
||||
persistSyncResponse bool
|
||||
}
|
||||
@@ -59,17 +59,19 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||
engineMutex: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
c.updateManager = um
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
@@ -187,14 +189,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
stateManager := statemanager.New(path)
|
||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||
if err == nil {
|
||||
updateManager.CheckUpdateSuccess(c.ctx)
|
||||
if c.updateManager != nil {
|
||||
c.updateManager.CheckUpdateSuccess(c.ctx)
|
||||
}
|
||||
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
@@ -308,7 +309,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmClient,
|
||||
RelayManager: relayManager,
|
||||
StatusRecorder: c.statusRecorder,
|
||||
Checks: checks,
|
||||
StateManager: stateManager,
|
||||
UpdateManager: c.updateManager,
|
||||
}, mobileDependency)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
@@ -318,15 +327,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||
if c.doInitialAutoUpdate {
|
||||
log.Infof("start engine by ui, run auto-update check")
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
c.doInitialAutoUpdate = false
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability() {}
|
||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
@@ -104,12 +104,16 @@ type DefaultServer struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
probeMu sync.Mutex
|
||||
probeCancel context.CancelFunc
|
||||
probeWg sync.WaitGroup
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
Stop()
|
||||
ProbeAvailability()
|
||||
ProbeAvailability(context.Context)
|
||||
ID() types.HandlerID
|
||||
}
|
||||
|
||||
@@ -362,7 +366,13 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
if s.probeCancel != nil {
|
||||
s.probeCancel()
|
||||
}
|
||||
s.ctxCancel()
|
||||
s.probeMu.Unlock()
|
||||
s.probeWg.Wait()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
@@ -479,7 +489,8 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
}
|
||||
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
// and deactivates the group if no server responds.
|
||||
// If a previous probe is still running, it will be cancelled before starting a new one.
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||
skipProbe, err := strconv.ParseBool(val)
|
||||
@@ -492,15 +503,52 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
wg.Add(1)
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.ProbeAvailability()
|
||||
}(mux.handler)
|
||||
s.probeMu.Lock()
|
||||
|
||||
// don't start probes on a stopped server
|
||||
if s.ctx.Err() != nil {
|
||||
s.probeMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// cancel any running probe
|
||||
if s.probeCancel != nil {
|
||||
s.probeCancel()
|
||||
s.probeCancel = nil
|
||||
}
|
||||
|
||||
// wait for the previous probe goroutines to finish while holding
|
||||
// the mutex so no other caller can start a new probe concurrently
|
||||
s.probeWg.Wait()
|
||||
|
||||
// start a new probe
|
||||
probeCtx, probeCancel := context.WithCancel(s.ctx)
|
||||
s.probeCancel = probeCancel
|
||||
|
||||
s.probeWg.Add(1)
|
||||
defer s.probeWg.Done()
|
||||
|
||||
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
|
||||
s.mux.Lock()
|
||||
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
handlers = append(handlers, mux.handler)
|
||||
}
|
||||
s.mux.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range handlers {
|
||||
wg.Add(1)
|
||||
go func(h handlerWithStop) {
|
||||
defer wg.Done()
|
||||
h.ProbeAvailability(probeCtx)
|
||||
}(handler)
|
||||
}
|
||||
|
||||
s.probeMu.Unlock()
|
||||
|
||||
wg.Wait()
|
||||
probeCancel()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
|
||||
@@ -1065,7 +1065,7 @@ type mockHandler struct {
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability() {}
|
||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
@@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -65,6 +65,7 @@ type upstreamResolverBase struct {
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
wg sync.WaitGroup
|
||||
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
@@ -115,6 +116,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
|
||||
u.mutex.Lock()
|
||||
u.wg.Wait()
|
||||
u.mutex.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
@@ -260,16 +266,10 @@ func formatFailures(failures []upstreamFailure) string {
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query
|
||||
if u.successCount.Load() > 0 {
|
||||
return
|
||||
@@ -279,31 +279,39 @@ func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errors *multierror.Error
|
||||
var errs *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
upstream := upstream
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
go func(upstream netip.AddrPort) {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(upstream, 500*time.Millisecond)
|
||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, err)
|
||||
mu.Lock()
|
||||
errs = multierror.Append(errs, err)
|
||||
mu.Unlock()
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
success = true
|
||||
}()
|
||||
mu.Unlock()
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errors.ErrorOrNil())
|
||||
u.disable(errs.ErrorOrNil())
|
||||
|
||||
if u.statusRecorder == nil {
|
||||
return
|
||||
@@ -339,7 +347,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(upstream, probeTimeout); err != nil {
|
||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
@@ -364,7 +372,9 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.mutex.Lock()
|
||||
u.disabled = false
|
||||
u.mutex.Unlock()
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
@@ -387,7 +397,11 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.waitUntilResponse()
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
@@ -398,13 +412,18 @@ func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
if externalCtx != nil {
|
||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||
defer stop2()
|
||||
}
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ProbeAvailability()
|
||||
resolver.ProbeAvailability(context.TODO())
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
|
||||
@@ -51,7 +51,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -79,7 +79,6 @@ const (
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
// EngineConfig is a config for the Engine
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -141,6 +140,17 @@ type EngineConfig struct {
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// EngineServices holds the external service dependencies required by the Engine.
|
||||
type EngineServices struct {
|
||||
SignalClient signal.Client
|
||||
MgmClient mgm.Client
|
||||
RelayManager *relayClient.Manager
|
||||
StatusRecorder *peer.Status
|
||||
Checks []*mgmProto.Checks
|
||||
StateManager *statemanager.Manager
|
||||
UpdateManager *updater.Manager
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
type Engine struct {
|
||||
// signal is a Signal Service client
|
||||
@@ -209,7 +219,7 @@ type Engine struct {
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updatemanager.Manager
|
||||
updateManager *updater.Manager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
@@ -239,22 +249,17 @@ type localIpUpdater interface {
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
stateManager *statemanager.Manager,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: signalClient,
|
||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||
mgmClient: mgmClient,
|
||||
relayManager: relayManager,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
@@ -262,11 +267,12 @@ func NewEngine(
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
checks: checks,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -309,7 +315,7 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.Stop()
|
||||
e.updateManager.SetDownloadOnly()
|
||||
}
|
||||
|
||||
log.Info("cleaning up status recorder states")
|
||||
@@ -559,13 +565,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||
}
|
||||
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
@@ -793,39 +792,22 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
if e.updateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if autoUpdateSettings == nil {
|
||||
return
|
||||
}
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
e.updateManager = nil
|
||||
if autoUpdateSettings.Version == disableAutoUpdate {
|
||||
log.Infof("auto-update is disabled")
|
||||
e.updateManager.SetDownloadOnly()
|
||||
return
|
||||
}
|
||||
|
||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||
return
|
||||
}
|
||||
|
||||
// Start manager if needed
|
||||
if e.updateManager == nil {
|
||||
log.Infof("starting auto-update manager")
|
||||
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e.updateManager = updateManager
|
||||
e.updateManager.Start(e.ctx)
|
||||
}
|
||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
@@ -842,7 +824,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
@@ -1315,8 +1297,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
e.dnsServer.ProbeAvailability()
|
||||
|
||||
go e.dnsServer.ProbeAvailability()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -251,9 +251,6 @@ func TestEngine_SSH(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
@@ -263,10 +260,13 @@ func TestEngine_SSH(t *testing.T) {
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -428,13 +428,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
@@ -647,13 +652,18 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -812,13 +822,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
@@ -1014,13 +1029,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
@@ -1546,7 +1566,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
@@ -12,9 +12,10 @@ const renewTimeout = 10 * time.Second
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
PortAutoAssigned bool
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
@@ -25,6 +26,7 @@ type Request struct {
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
ListenPort uint16
|
||||
}
|
||||
|
||||
type ManagementClient interface {
|
||||
|
||||
@@ -15,6 +15,7 @@ func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
UserGroups: req.UserGroups,
|
||||
Domain: req.Domain,
|
||||
NamePrefix: req.NamePrefix,
|
||||
ListenPort: uint16(req.ListenPort),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,13 +28,15 @@ func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
ListenPort: req.ListenPort,
|
||||
}
|
||||
}
|
||||
|
||||
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
|
||||
return &Response{
|
||||
ServiceName: response.ServiceName,
|
||||
Domain: response.Domain,
|
||||
ServiceURL: response.ServiceURL,
|
||||
ServiceName: response.ServiceName,
|
||||
Domain: response.Domain,
|
||||
ServiceURL: response.ServiceURL,
|
||||
PortAutoAssigned: response.PortAutoAssigned,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -564,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||
return dnsinterceptor.New(params)
|
||||
case handlerTypeDynamic:
|
||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort()))
|
||||
return dynamic.NewRoute(params, dnsAddr)
|
||||
default:
|
||||
return static.NewRoute(params)
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
v.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) LatestVersion() *v.Version {
|
||||
return v.latestVersion
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StartFetcher() {}
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: false,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest")
|
||||
var triggeredInit bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredInit = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredInit = false
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
var triggeredLater bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredLater = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredLater = false
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Update to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion)
|
||||
|
||||
var updateTriggered bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
|
||||
}
|
||||
updateTriggered = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
updateTriggered = false
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Manager is a no-op stub for unsupported platforms
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a no-op manager for unsupported platforms
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
return nil, fmt.Errorf("update manager is not supported on this platform")
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess is a no-op on unsupported platforms
|
||||
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Start is a no-op on unsupported platforms
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// SetVersion is a no-op on unsupported platforms
|
||||
func (m *Manager) SetVersion(expectedVersion string) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Stop is a no-op on unsupported platforms
|
||||
func (m *Manager) Stop() {
|
||||
// no-op
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package updatemanager provides automatic update management for the NetBird client.
|
||||
// Package updater provides automatic update management for the NetBird client.
|
||||
// It monitors for new versions, handles update triggers from management server directives,
|
||||
// and orchestrates the download and installation of client updates.
|
||||
//
|
||||
@@ -32,4 +32,4 @@
|
||||
//
|
||||
// This enables verification of successful updates and appropriate user notification
|
||||
// after the client restarts with the new version.
|
||||
package updatemanager
|
||||
package updater
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/downloader"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
type Installer struct {
|
||||
@@ -203,7 +203,10 @@ func (rh *ResultHandler) write(result Result) error {
|
||||
|
||||
func (rh *ResultHandler) cleanup() error {
|
||||
err := os.Remove(rh.resultFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Debugf("delete installer result file: %s", rh.resultFile)
|
||||
@@ -1,12 +1,9 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updatemanager
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -15,7 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -41,6 +38,9 @@ type Manager struct {
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
downloadOnly bool // true when no enforcement from management; notifies UI to download latest
|
||||
forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly
|
||||
|
||||
lastTrigger time.Time
|
||||
mgmUpdateChan chan struct{}
|
||||
updateChannel chan struct{}
|
||||
@@ -53,37 +53,38 @@ type Manager struct {
|
||||
expectedVersion *v.Version
|
||||
updateToLatestVersion bool
|
||||
|
||||
// updateMutex protect update and expectedVersion fields
|
||||
pendingVersion *v.Version
|
||||
|
||||
// updateMutex protects update, expectedVersion, updateToLatestVersion,
|
||||
// downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields
|
||||
updateMutex sync.Mutex
|
||||
|
||||
triggerUpdateFn func(context.Context, string) error
|
||||
// installMutex and installing guard against concurrent installation attempts
|
||||
installMutex sync.Mutex
|
||||
installing bool
|
||||
|
||||
// protect to start the service multiple times
|
||||
mu sync.Mutex
|
||||
|
||||
autoUpdateSupported func() bool
|
||||
}
|
||||
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
|
||||
if isBrew {
|
||||
log.Warnf("auto-update disabled on Home Brew installation")
|
||||
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
|
||||
}
|
||||
}
|
||||
return newManager(statusRecorder, stateManager)
|
||||
}
|
||||
|
||||
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted.
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager {
|
||||
manager := &Manager{
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
downloadOnly: true,
|
||||
autoUpdateSupported: isAutoUpdateSupported,
|
||||
}
|
||||
manager.triggerUpdateFn = manager.triggerUpdate
|
||||
|
||||
stateManager.RegisterState(&UpdateState{})
|
||||
|
||||
return manager, nil
|
||||
return manager
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess checks if the update was successful and send a notification.
|
||||
@@ -124,8 +125,10 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
log.Infof("starting update manager")
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.cancel != nil {
|
||||
log.Errorf("Manager already started")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -142,13 +145,32 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
m.cancel = cancel
|
||||
|
||||
m.wg.Add(1)
|
||||
go m.updateLoop(ctx)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
m.updateLoop(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) SetVersion(expectedVersion string) {
|
||||
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
|
||||
if m.cancel == nil {
|
||||
log.Errorf("manager not started")
|
||||
func (m *Manager) SetDownloadOnly() {
|
||||
m.updateMutex.Lock()
|
||||
m.downloadOnly = true
|
||||
m.forceUpdate = false
|
||||
m.expectedVersion = nil
|
||||
m.updateToLatestVersion = false
|
||||
m.lastTrigger = time.Time{}
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
select {
|
||||
case m.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) {
|
||||
log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate)
|
||||
|
||||
if !m.autoUpdateSupported() {
|
||||
log.Warnf("auto-update not supported on this platform")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -159,6 +181,7 @@ func (m *Manager) SetVersion(expectedVersion string) {
|
||||
log.Errorf("empty expected version provided")
|
||||
m.expectedVersion = nil
|
||||
m.updateToLatestVersion = false
|
||||
m.downloadOnly = true
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,12 +201,97 @@ func (m *Manager) SetVersion(expectedVersion string) {
|
||||
m.updateToLatestVersion = false
|
||||
}
|
||||
|
||||
m.lastTrigger = time.Time{}
|
||||
m.downloadOnly = false
|
||||
m.forceUpdate = forceUpdate
|
||||
|
||||
select {
|
||||
case m.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI.
|
||||
func (m *Manager) Install(ctx context.Context) error {
|
||||
if !m.autoUpdateSupported() {
|
||||
return fmt.Errorf("auto-update not supported on this platform")
|
||||
}
|
||||
|
||||
m.updateMutex.Lock()
|
||||
pending := m.pendingVersion
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
if pending == nil {
|
||||
return fmt.Errorf("no pending version to install")
|
||||
}
|
||||
|
||||
return m.tryInstall(ctx, pending)
|
||||
}
|
||||
|
||||
// tryInstall ensures only one installation runs at a time. Concurrent callers
|
||||
// receive an error immediately rather than queuing behind a running install.
|
||||
func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error {
|
||||
m.installMutex.Lock()
|
||||
if m.installing {
|
||||
m.installMutex.Unlock()
|
||||
return fmt.Errorf("installation already in progress")
|
||||
}
|
||||
m.installing = true
|
||||
m.installMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
m.installMutex.Lock()
|
||||
m.installing = false
|
||||
m.installMutex.Unlock()
|
||||
}()
|
||||
|
||||
return m.install(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// NotifyUI re-publishes the current update state to a newly connected UI client.
|
||||
// Only needed for download-only mode where the latest version is already cached
|
||||
// NotifyUI re-publishes the current update state so a newly connected UI gets the info.
|
||||
func (m *Manager) NotifyUI() {
|
||||
m.updateMutex.Lock()
|
||||
if m.update == nil {
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
downloadOnly := m.downloadOnly
|
||||
pendingVersion := m.pendingVersion
|
||||
latestVersion := m.update.LatestVersion()
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
if downloadOnly {
|
||||
if latestVersion == nil {
|
||||
return
|
||||
}
|
||||
currentVersion, err := v.NewVersion(m.currentVersion)
|
||||
if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) {
|
||||
return
|
||||
}
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": latestVersion.String()},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if pendingVersion != nil {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it.
|
||||
func (m *Manager) Stop() {
|
||||
if m.cancel == nil {
|
||||
return
|
||||
@@ -214,8 +322,6 @@ func (m *Manager) onContextCancel() {
|
||||
}
|
||||
|
||||
func (m *Manager) updateLoop(ctx context.Context) {
|
||||
defer m.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -239,55 +345,89 @@ func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
expectedVersion := m.expectedVersion
|
||||
useLatest := m.updateToLatestVersion
|
||||
downloadOnly := m.downloadOnly
|
||||
forceUpdate := m.forceUpdate
|
||||
curLatestVersion := m.update.LatestVersion()
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
switch {
|
||||
// Resolve "latest" to actual version
|
||||
case useLatest:
|
||||
// Download-only mode or resolve "latest" to actual version
|
||||
case downloadOnly, m.updateToLatestVersion:
|
||||
if curLatestVersion == nil {
|
||||
log.Tracef("latest version not fetched yet")
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
updateVersion = curLatestVersion
|
||||
// Update to specific version
|
||||
case expectedVersion != nil:
|
||||
updateVersion = expectedVersion
|
||||
// Install to specific version
|
||||
case m.expectedVersion != nil:
|
||||
updateVersion = m.expectedVersion
|
||||
default:
|
||||
log.Debugf("no expected version information set")
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||
if !m.shouldUpdate(updateVersion) {
|
||||
if !m.shouldUpdate(updateVersion, forceUpdate) {
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
m.lastTrigger = time.Now()
|
||||
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Automatically updating client",
|
||||
"Your client version is older than auto-update version set in Management, updating client now.",
|
||||
nil,
|
||||
)
|
||||
log.Infof("new version available: %s", updateVersion)
|
||||
|
||||
if !downloadOnly && !forceUpdate {
|
||||
m.pendingVersion = updateVersion
|
||||
}
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
if downloadOnly {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": updateVersion.String()},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if forceUpdate {
|
||||
if err := m.tryInstall(ctx, updateVersion); err != nil {
|
||||
log.Errorf("force update failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Updating client",
|
||||
"Installing update now.",
|
||||
nil,
|
||||
)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "show", "version": updateVersion.String()},
|
||||
map[string]string{"progress_window": "show", "version": pendingVersion.String()},
|
||||
)
|
||||
|
||||
updateState := UpdateState{
|
||||
PreUpdateVersion: m.currentVersion,
|
||||
TargetVersion: updateVersion.String(),
|
||||
TargetVersion: pendingVersion.String(),
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(updateState); err != nil {
|
||||
log.Warnf("failed to update state: %v", err)
|
||||
} else {
|
||||
@@ -296,8 +436,9 @@ func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
|
||||
log.Errorf("Error triggering auto-update: %v", err)
|
||||
inst := installer.New()
|
||||
if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil {
|
||||
log.Errorf("error triggering update: %v", err)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
@@ -305,7 +446,9 @@ func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
fmt.Sprintf("Auto-update failed: %v", err),
|
||||
nil,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
|
||||
@@ -339,7 +482,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
|
||||
return updateState, nil
|
||||
}
|
||||
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
|
||||
if m.currentVersion == developmentVersion {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
@@ -354,8 +497,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Since(m.lastTrigger) < 5*time.Minute {
|
||||
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
|
||||
if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute {
|
||||
log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -367,8 +510,3 @@ func (m *Manager) lastResultErrReason() string {
|
||||
result := installer.NewResultHandler(inst.TempDir())
|
||||
return result.GetErrorResultReason()
|
||||
}
|
||||
|
||||
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
inst := installer.New()
|
||||
return inst.RunInstallation(ctx, targetVersion)
|
||||
}
|
||||
111
client/internal/updater/manager_linux_test.go
Normal file
111
client/internal/updater/manager_linux_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// On Linux, only Mode 1 (downloadOnly) is supported.
|
||||
// SetVersion is a no-op because auto-update installation is not supported.
|
||||
|
||||
func Test_LatestVersion_Linux(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should notify again when a newer version arrives even within 5 minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't notify initially, but should notify as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetDownloadOnly()
|
||||
|
||||
ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredInit := ver != ""
|
||||
if enforced {
|
||||
t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name)
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial notify mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
ver, enforced = waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredLater := ver != ""
|
||||
if enforced {
|
||||
t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name)
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Later notify mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() {
|
||||
t.Errorf("%s: Later version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetVersion_NoOp_Linux(t *testing.T) {
|
||||
// On Linux, SetVersion should be a no-op — no events fired
|
||||
tmpFile := path.Join(t.TempDir(), "update-test-noop.json")
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))}
|
||||
m.currentVersion = "1.0.0"
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("1.0.1", false)
|
||||
|
||||
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
if ver != "" {
|
||||
t.Errorf("SetVersion should be a no-op on Linux, but got event with version %s", ver)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
227
client/internal/updater/manager_test.go
Normal file
227
client/internal/updater/manager_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should notify again when a newer version arrives even within 5 minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.autoUpdateSupported = func() bool { return true }
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest", false)
|
||||
|
||||
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredInit := ver != ""
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
ver, _ = waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredLater := ver != ""
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Later update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() {
|
||||
t.Errorf("%s: Later update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Install to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Install to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Install to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Install to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Install to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.autoUpdateSupported = func() bool { return true }
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion, false)
|
||||
|
||||
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
updateTriggered := ver != ""
|
||||
|
||||
if updateTriggered {
|
||||
if c.expectedVersion == "latest" && c.latestVersion != nil && ver != c.latestVersion.String() {
|
||||
t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
|
||||
} else if c.expectedVersion != "latest" && c.expectedVersion != "development" && ver != c.expectedVersion {
|
||||
t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.expectedVersion, ver)
|
||||
}
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_EnforcedMetadata(t *testing.T) {
|
||||
// Mode 1 (downloadOnly): no enforced metadata
|
||||
tmpFile := path.Join(t.TempDir(), "update-test-mode1.json")
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))}
|
||||
m.currentVersion = "1.0.0"
|
||||
m.Start(context.Background())
|
||||
m.SetDownloadOnly()
|
||||
|
||||
ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
if ver == "" {
|
||||
t.Fatal("Mode 1: expected new_version_available event")
|
||||
}
|
||||
if enforced {
|
||||
t.Error("Mode 1: expected no enforced metadata")
|
||||
}
|
||||
m.Stop()
|
||||
|
||||
// Mode 2 (enforced, forceUpdate=false): enforced metadata present, no auto-install
|
||||
tmpFile2 := path.Join(t.TempDir(), "update-test-mode2.json")
|
||||
recorder2 := peer.NewRecorder("")
|
||||
sub2 := recorder2.SubscribeToEvents()
|
||||
defer recorder2.UnsubscribeFromEvents(sub2)
|
||||
|
||||
m2 := NewManager(recorder2, statemanager.New(tmpFile2))
|
||||
m2.update = &versionUpdateMock{latestVersion: nil}
|
||||
m2.currentVersion = "1.0.0"
|
||||
m2.autoUpdateSupported = func() bool { return true }
|
||||
m2.Start(context.Background())
|
||||
m2.SetVersion("1.0.1", false)
|
||||
|
||||
ver, enforced2 := waitForUpdateEvent(sub2, 500*time.Millisecond)
|
||||
if ver == "" {
|
||||
t.Fatal("Mode 2: expected new_version_available event")
|
||||
}
|
||||
if !enforced2 {
|
||||
t.Error("Mode 2: expected enforced metadata")
|
||||
}
|
||||
m2.Stop()
|
||||
}
|
||||
|
||||
// ensure the proto import is used
|
||||
var _ = cProto.SystemEvent_INFO
|
||||
56
client/internal/updater/manager_test_helpers_test.go
Normal file
56
client/internal/updater/manager_test_helpers_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (m versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
m.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (m versionUpdateMock) LatestVersion() *v.Version {
|
||||
return m.latestVersion
|
||||
}
|
||||
|
||||
func (m versionUpdateMock) StartFetcher() {}
|
||||
|
||||
// waitForUpdateEvent waits for a new_version_available event, returns the version string or "" on timeout.
|
||||
func waitForUpdateEvent(sub *peer.EventSubscription, timeout time.Duration) (version string, enforced bool) {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-sub.Events():
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if val, ok := event.Metadata["new_version_available"]; ok {
|
||||
enforced := false
|
||||
if raw, ok := event.Metadata["enforced"]; ok {
|
||||
if parsed, err := strconv.ParseBool(raw); err == nil {
|
||||
enforced = parsed
|
||||
}
|
||||
}
|
||||
return val, enforced
|
||||
}
|
||||
case <-timer.C:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/downloader"
|
||||
)
|
||||
|
||||
const (
|
||||
22
client/internal/updater/supported_darwin.go
Normal file
22
client/internal/updater/supported_darwin.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
)
|
||||
|
||||
func isAutoUpdateSupported() bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
isBrew := !installer.TypeOfInstaller(ctx).Downloadable()
|
||||
if isBrew {
|
||||
log.Warnf("auto-update disabled on Homebrew installation")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
7
client/internal/updater/supported_other.go
Normal file
7
client/internal/updater/supported_other.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package updater
|
||||
|
||||
func isAutoUpdateSupported() bool {
|
||||
return false
|
||||
}
|
||||
5
client/internal/updater/supported_windows.go
Normal file
5
client/internal/updater/supported_windows.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package updater
|
||||
|
||||
func isAutoUpdateSupported() bool {
|
||||
return true
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package updatemanager
|
||||
package updater
|
||||
|
||||
import v "github.com/hashicorp/go-version"
|
||||
|
||||
Reference in New Issue
Block a user