From 869537c9511edbaf9f22bbcbffee5a2bcbbc3b93 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 24 Oct 2024 10:53:46 +0200 Subject: [PATCH] [client] Cleanup dns and route states on startup (#2757) --- client/firewall/nftables/state.go | 1 + client/internal/connect.go | 12 +- client/internal/dns/consts_freebsd.go | 3 +- client/internal/dns/consts_linux.go | 3 +- client/internal/dns/file_repair_unix.go | 8 +- client/internal/dns/file_repair_unix_test.go | 9 +- client/internal/dns/file_unix.go | 24 +- client/internal/dns/host.go | 19 +- client/internal/dns/host_android.go | 12 +- client/internal/dns/host_darwin.go | 18 +- client/internal/dns/host_ios.go | 11 +- client/internal/dns/host_unix.go | 28 +- client/internal/dns/host_windows.go | 20 +- client/internal/dns/network_manager_unix.go | 21 +- client/internal/dns/resolvconf_unix.go | 23 +- client/internal/dns/server.go | 40 ++- client/internal/dns/server_test.go | 10 +- client/internal/dns/server_windows.go | 2 +- client/internal/dns/systemd_freebsd.go | 2 +- client/internal/dns/systemd_linux.go | 21 +- .../internal/dns/unclean_shutdown_android.go | 5 - .../internal/dns/unclean_shutdown_darwin.go | 48 +-- client/internal/dns/unclean_shutdown_ios.go | 5 - .../internal/dns/unclean_shutdown_mobile.go | 14 + client/internal/dns/unclean_shutdown_unix.go | 83 ++--- .../internal/dns/unclean_shutdown_windows.go | 69 +--- client/internal/engine.go | 35 +- client/internal/routemanager/manager.go | 15 +- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 9 +- .../internal/routemanager/systemops/state.go | 81 +++++ .../systemops/systemops_android.go | 9 +- .../systemops/systemops_generic.go | 51 ++- .../systemops/systemops_generic_test.go | 8 +- .../routemanager/systemops/systemops_ios.go | 25 +- .../routemanager/systemops/systemops_linux.go | 13 +- .../routemanager/systemops/systemops_unix.go | 9 +- .../systemops/systemops_windows.go | 9 +- 
client/internal/statemanager/manager.go | 298 ++++++++++++++++++ client/internal/statemanager/path.go | 35 ++ client/ios/NetBirdSDK/client.go | 4 +- client/server/server.go | 47 +++ 42 files changed, 786 insertions(+), 377 deletions(-) create mode 100644 client/firewall/nftables/state.go delete mode 100644 client/internal/dns/unclean_shutdown_android.go delete mode 100644 client/internal/dns/unclean_shutdown_ios.go create mode 100644 client/internal/dns/unclean_shutdown_mobile.go create mode 100644 client/internal/routemanager/systemops/state.go create mode 100644 client/internal/statemanager/manager.go create mode 100644 client/internal/statemanager/path.go diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go new file mode 100644 index 000000000..7027fe987 --- /dev/null +++ b/client/firewall/nftables/state.go @@ -0,0 +1 @@ +package nftables diff --git a/client/internal/connect.go b/client/internal/connect.go index 74dc1f1b5..13f10fbf1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -117,12 +117,6 @@ func (c *ConnectClient) run( log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) - // Check if client was not shut down in a clean way and restore DNS config if required. - // Otherwise, we might not be able to connect to the management server to retrieve new config. 
- if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { - log.Errorf("checking unclean shutdown error: %s", err) - } - backOff := &backoff.ExponentialBackOff{ InitialInterval: time.Second, RandomizationFactor: 1, @@ -358,7 +352,11 @@ func (c *ConnectClient) Stop() error { if c.engine == nil { return nil } - return c.engine.Stop() + if err := c.engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } + + return nil } func (c *ConnectClient) isContextCancelled() bool { diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go index 958eca8e5..64c8fe5eb 100644 --- a/client/internal/dns/consts_freebsd.go +++ b/client/internal/dns/consts_freebsd.go @@ -1,6 +1,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" ) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go index 32456a50f..15614b0c5 100644 --- a/client/internal/dns/consts_linux.go +++ b/client/internal/dns/consts_linux.go @@ -3,6 +3,5 @@ package dns const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" - fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager" + fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" ) diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index ae2c33b86..9a9218fa1 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -9,6 +9,8 @@ import ( "github.com/fsnotify/fsnotify" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) var ( @@ -20,7 +22,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf) error +type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error type repair 
struct { operationFile string @@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { if f.inotify != nil { return } @@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) } - err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf) + err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager) if err != nil { log.Errorf("failed to repair resolv.conf: %v", err) } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index 4dba79e99..e948557b6 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/util" ) @@ -104,14 +105,14 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := newRepair(operationFile, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) if err != nil { @@ -151,14 +152,14 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf) error { + updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { changed = true cancel() return nil } r := 
newRepair(tmpLink, updateFn) - r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") + r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) if err != nil { diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 624e089cb..02ae26e10 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -11,6 +11,8 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -36,7 +38,7 @@ type fileConfigurator struct { nbNameserverIP string } -func newFileConfigurator() (hostManager, error) { +func newFileConfigurator() (*fileConfigurator, error) { fc := &fileConfigurator{} fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) return fc, nil @@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool { return false } -func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { backupFileExist := f.isBackupFileExist() if !config.RouteAll { if backupFileExist { @@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { f.repair.stopWatchFileChanges() - err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) + err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) if err != nil { return err } - f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) + f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager) return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error { +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) 
nameServers := generateNsList(nbNameserverIP, cfg) @@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) // create another backup for unclean shutdown detection right after overwriting the original resolv.conf - if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil { + if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil { log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) } @@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return os.RemoveAll(fileDefaultResolvConfBackupLocation) } @@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add return restoreResolvConfFile() } - log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring") + log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress) return nil } @@ -192,10 +190,6 @@ func restoreResolvConfFile() error { return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err) - } - return nil } diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go 
index e55a07055..e2b5f699a 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -5,14 +5,14 @@ import ( "net/netip" "strings" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) type hostManager interface { - applyDNSConfig(config HostDNSConfig) error + applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNS() error supportCustomPort() bool - restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error } type SystemDNSSettings struct { @@ -35,15 +35,15 @@ type DomainConfig struct { } type mockHostConfigurator struct { - applyDNSConfigFunc func(config HostDNSConfig) error + applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error restoreHostDNSFunc func() error supportCustomPortFunc func() bool restoreUncleanShutdownDNSFunc func(*netip.Addr) error } -func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { if m.applyDNSConfigFunc != nil { - return m.applyDNSConfigFunc(config) + return m.applyDNSConfigFunc(config, stateManager) } return fmt.Errorf("method applyDNSSettings is not implemented") } @@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool { return false } -func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error { - if m.restoreUncleanShutdownDNSFunc != nil { - return m.restoreUncleanShutdownDNSFunc(storedDNSAddress) - } - return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented") -} - func newNoopHostMocker() hostManager { return &mockHostConfigurator{ - applyDNSConfigFunc: func(config HostDNSConfig) error { return nil }, + applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, restoreHostDNSFunc: func() error { return nil }, supportCustomPortFunc: func() bool { 
return true }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index 9230cb257..5653710d7 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -1,15 +1,17 @@ package dns -import "net/netip" +import ( + "github.com/netbirdio/netbird/client/internal/statemanager" +) type androidHostManager struct { } -func newHostManager() (hostManager, error) { +func newHostManager() (*androidHostManager, error) { return &androidHostManager{}, nil } -func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { return nil } @@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error { func (a androidHostManager) supportCustomPort() bool { return false } - -func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 5dee305c2..b8ba33e34 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -8,12 +8,13 @@ import ( "fmt" "io" "net" - "net/netip" "os/exec" "strconv" "strings" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -37,7 +38,7 @@ type systemConfigurator struct { systemDNSSettings SystemDNSSettings } -func newHostManager() (hostManager, error) { +func newHostManager() (*systemConfigurator, error) { return &systemConfigurator{ createdKeys: make(map[string]struct{}), }, nil @@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool { return true } -func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error - // create a file for unclean shutdown detection 
- if err := createUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error { } } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) { return primaryService, router, nil } -func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { +func (s *systemConfigurator) restoreUncleanShutdownDNS() error { if err := s.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via scutil: %w", err) } diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index ad8b14fb8..4a0acf572 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -3,9 +3,10 @@ package dns import ( "encoding/json" "fmt" - "net/netip" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) type iosHostManager struct { @@ -13,13 +14,13 @@ type iosHostManager struct { config HostDNSConfig } -func newHostManager(dnsManager IosDnsManager) (hostManager, error) { +func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { return &iosHostManager{ dnsManager: dnsManager, }, nil } -func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { +func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { return fmt.Errorf("marshal: %w", err) @@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error { func (a iosHostManager) supportCustomPort() bool { return false } - -func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error { - return nil -} 
diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 72b8f6c6e..7bd4aec64 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -4,9 +4,9 @@ package dns import ( "bufio" - "errors" "fmt" "io" + "net/netip" "os" "strings" @@ -21,27 +21,8 @@ const ( resolvConfManager ) -var ErrUnknownOsManagerType = errors.New("unknown os manager type") - type osManagerType int -func newOsManagerType(osManager string) (osManagerType, error) { - switch osManager { - case "netbird": - return fileManager, nil - case "file": - return netbirdManager, nil - case "networkManager": - return networkManager, nil - case "systemd": - return systemdManager, nil - case "resolvconf": - return resolvConfManager, nil - default: - return 0, ErrUnknownOsManagerType - } -} - func (t osManagerType) String() string { switch t { case netbirdManager: @@ -59,6 +40,11 @@ func (t osManagerType) String() string { } } +type restoreHostManager interface { + hostManager + restoreUncleanShutdownDNS(*netip.Addr) error +} + func newHostManager(wgInterface string) (hostManager, error) { osManager, err := getOSDNSManagerType() if err != nil { @@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) { return newHostManagerFromType(wgInterface, osManager) } -func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) { +func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { switch osManager { case networkManager: return newNetworkManagerDbusConfigurator(wgInterface) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index c8bf2e552..7ecca8a41 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -3,11 +3,12 @@ package dns import ( "fmt" "io" - "net/netip" "strings" log "github.com/sirupsen/logrus" "golang.org/x/sys/windows/registry" + 
"github.com/netbirdio/netbird/client/internal/statemanager" ) const ( @@ -31,7 +32,7 @@ type registryConfigurator struct { routingAll bool } -func newHostManager(wgInterface WGIface) (hostManager, error) { +func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { guid, err := wgInterface.GetInterfaceGUIDString() if err != nil { return nil, err @@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) { return newHostManagerWithGuid(guid) } -func newHostManagerWithGuid(guid string) (hostManager, error) { +func newHostManagerWithGuid(guid string) (*registryConfigurator, error) { return &registryConfigurator{ guid: guid, }, nil @@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool { return false } -func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error if config.RouteAll { err = r.addDNSSetupForAll(config.ServerIP) @@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - // create a file for unclean shutdown detection - if err := createUncleanShutdownIndicator(r.guid); err != nil { - log.Errorf("failed to create unclean shutdown file: %s", err) + if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } var ( @@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error { return fmt.Errorf("remove interface registry key: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown file: %s", err) - } - return nil } @@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) { return regKey, nil } -func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) 
error { +func (r *registryConfigurator) restoreUncleanShutdownDNS() error { if err := r.restoreHostDNS(); err != nil { return fmt.Errorf("restoring dns via registry: %w", err) } diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 184047a64..63bbead77 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -16,6 +16,7 @@ import ( "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbversion "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{ type networkManagerDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() { } } -func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) { +func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) { obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) if err != nil { return nil, fmt.Errorf("get nm dbus: %w", err) @@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) return &networkManagerDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool { return false } -func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { +func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { connSettings, configVersion, err := n.getAppliedConnectionSettings() if err != nil { return fmt.Errorf("retrieving the applied connection settings, error: %w", err) @@ 
-151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. - // The file content itself is not important for network-manager restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: networkManager, + WgIface: n.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("delete connection settings: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return nil } diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 0c17626c7..a5d1cc8a2 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -9,6 +9,8 @@ import ( "os/exec" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/statemanager" ) const resolvconfCommand = "resolvconf" @@ -22,7 +24,7 @@ type resolvconf struct { } // supported "openresolv" only -func newResolvConfConfigurator(wgInterface string) (hostManager, error) { +func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { resolvConfEntries, 
err := parseDefaultResolvConf() if err != nil { log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) @@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool { return false } -func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { +func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { var err error if !config.RouteAll { err = r.restoreHostDNS() @@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { append([]string{config.ServerIP}, r.originalNameServers...), options) - // create a backup for unclean shutdown detection before the resolv.conf is changed - if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: resolvConfManager, + WgIface: r.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } err = r.applyConfig(buf) @@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error { cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) _, err := cmd.Output() if err != nil { - return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err) - } - - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) + return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err) } return nil @@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error { cmd.Stdin = &content _, err := cmd.Output() if err != nil { - return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err) + return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) } return nil } 
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index a4651ebb5..772797fac 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,6 +7,7 @@ import ( "runtime" "strings" "sync" + "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -14,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -63,6 +65,7 @@ type DefaultServer struct { iosDnsManager IosDnsManager statusRecorder *peer.Status + stateManager *statemanager.Manager } type handlerWithStop interface { @@ -77,12 +80,7 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer( - ctx context.Context, - wgInterface WGIface, - customAddress string, - statusRecorder *peer.Status, -) (*DefaultServer, error) { +func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { var addrPort *netip.AddrPort if customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -99,7 +97,7 @@ func NewDefaultServer( dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil } // NewDefaultServerPermanentUpstream returns a new dns server. 
It optimized for mobile systems @@ -112,7 +110,7 @@ func NewDefaultServerPermanentUpstream( statusRecorder *peer.Status, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -130,12 +128,12 @@ func NewDefaultServerIos( iosDnsManager IosDnsManager, statusRecorder *peer.Status, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { +func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, @@ -147,6 +145,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi }, wgInterface: wgInterface, statusRecorder: statusRecorder, + stateManager: stateManager, hostsDNSHolder: newHostsDNSHolder(), } @@ -169,6 +168,7 @@ func (s *DefaultServer) Initialize() (err error) { } } + s.stateManager.RegisterState(&ShutdownState{}) s.hostManager, err = s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) @@ -191,9 +191,10 @@ func (s *DefaultServer) Stop() { s.ctxCancel() if s.hostManager != nil { - err := s.hostManager.restoreHostDNS() - if err != nil { - log.Error(err) + if err := s.hostManager.restoreHostDNS(); err != nil { + log.Error("failed to restore host DNS settings: ", err) + } else if err := 
s.stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete shutdown dns state: %v", err) } } @@ -318,10 +319,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { hostUpdate.RouteAll = false } - if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { + if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { log.Error(err) } + // persist dns state right away + ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) + defer cancel() + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) } @@ -521,7 +529,7 @@ func (s *DefaultServer) upstreamCallbacks( } } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } @@ -551,7 +559,7 @@ func (s *DefaultServer) upstreamCallbacks( s.currentConfig.RouteAll = true s.service.RegisterMux(nbdns.RootZone, handler) } - if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 4a5aff3ea..21f1f1b7d 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" 
"github.com/netbirdio/netbird/formatter" @@ -291,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Fatal(err) } @@ -400,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -495,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) if err != nil { t.Fatalf("%v", err) } @@ -554,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { hostManager := &mockHostConfigurator{} server := DefaultServer{ + ctx: context.Background(), service: NewServiceViaMemory(&mocWGIface{}), localResolver: &localResolver{ registeredMap: make(registrationMap), @@ -570,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { } var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error { + hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { domains := []string{} for _, item := range config.Domains { if item.Disabled { diff --git a/client/internal/dns/server_windows.go b/client/internal/dns/server_windows.go index 5e1494e9e..bc051d59b 100644 --- a/client/internal/dns/server_windows.go +++ b/client/internal/dns/server_windows.go @@ -1,5 +1,5 @@ package dns 
-func (s *DefaultServer) initialize() (manager hostManager, err error) { +func (s *DefaultServer) initialize() (hostManager, error) { return newHostManager(s.wgInterface) } diff --git a/client/internal/dns/systemd_freebsd.go b/client/internal/dns/systemd_freebsd.go index 0de805337..41c8bf019 100644 --- a/client/internal/dns/systemd_freebsd.go +++ b/client/internal/dns/systemd_freebsd.go @@ -7,7 +7,7 @@ import ( var errNotImplemented = errors.New("not implemented") -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(string) (restoreHostManager, error) { return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) } diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index e2fa5b71a..a031be582 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "github.com/netbirdio/netbird/client/internal/statemanager" nbdns "github.com/netbirdio/netbird/dns" ) @@ -38,6 +39,7 @@ const ( type systemdDbusConfigurator struct { dbusLinkObject dbus.ObjectPath routingAll bool + ifaceName string } // the types below are based on dbus specification, each field is mapped to a dbus type @@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct { MatchOnly bool } -func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { +func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) { iface, err := net.InterfaceByName(wgInterface) if err != nil { return nil, fmt.Errorf("get interface: %w", err) @@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { return &systemdDbusConfigurator{ dbusLinkObject: dbus.ObjectPath(s), + ifaceName: wgInterface, }, nil } @@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { return true } -func (s *systemdDbusConfigurator) 
applyDNSConfig(config HostDNSConfig) error { +func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { parsedIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %w", err) @@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) } - // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. - // The file content itself is not important for systemd restoration - if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil { - log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) + state := &ShutdownState{ + ManagerType: systemdManager, + WgIface: s.ifaceName, + } + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update shutdown state: %s", err) } log.Infof("adding %d search domains and %d match domains. 
Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) @@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error { return fmt.Errorf("unable to revert link configuration, got error: %w", err) } - if err := removeUncleanShutdownIndicator(); err != nil { - log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err) - } - return s.flushCaches() } diff --git a/client/internal/dns/unclean_shutdown_android.go b/client/internal/dns/unclean_shutdown_android.go deleted file mode 100644 index 105fb00bf..000000000 --- a/client/internal/dns/unclean_shutdown_android.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index e077ec84d..9bbdd2b56 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -3,57 +3,25 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - log "github.com/sirupsen/logrus" ) -const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns" +type ShutdownState struct { +} -func CheckUncleanShutdown(string) error { - if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - log.Warnf("detected unclean shutdown, file %s exists. 
Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation) +func (s *ShutdownState) Name() string { + return "dns_state" +} +func (s *ShutdownState) Cleanup() error { manager, err := newHostManager() if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func createUncleanShutdownIndicator() error { - dir := filepath.Dir(fileUncleanShutdownFileLocation) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err) - } - - return nil -} - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err) - } - return nil -} diff --git a/client/internal/dns/unclean_shutdown_ios.go b/client/internal/dns/unclean_shutdown_ios.go deleted file mode 100644 index 105fb00bf..000000000 --- a/client/internal/dns/unclean_shutdown_ios.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -func CheckUncleanShutdown(string) error { - return nil -} diff --git a/client/internal/dns/unclean_shutdown_mobile.go b/client/internal/dns/unclean_shutdown_mobile.go new file mode 100644 index 000000000..0d3a2cdbd --- /dev/null +++ b/client/internal/dns/unclean_shutdown_mobile.go @@ -0,0 +1,14 @@ +//go:build ios || android + +package dns + +type ShutdownState struct { +} + +func (s *ShutdownState) Name() string { + return "dns_state" +} + +func (s *ShutdownState) Cleanup() error { + return nil +} diff --git 
a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index 8a32090c3..fcf60c694 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -3,66 +3,44 @@ package dns import ( - "errors" "fmt" - "io/fs" "net/netip" "os" "path/filepath" - "strings" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" ) -func CheckUncleanShutdown(wgIface string) error { - if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } +type ShutdownState struct { + ManagerType osManagerType + DNSAddress netip.Addr + WgIface string +} - log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation) +func (s *ShutdownState) Name() string { + return "dns_state" +} - managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation) - if err != nil { - return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - - managerFields := strings.Split(string(managerData), ",") - if len(managerFields) < 2 { - return errors.New("split manager data: insufficient number of fields") - } - osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1] - - dnsAddress, err := netip.ParseAddr(dnsAddressStr) - if err != nil { - return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err) - } - - log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr) - - // determine os manager type, so we can invoke the respective restore action - osManagerType, err := newOsManagerType(osManagerTypeStr) - if err != nil { - return fmt.Errorf("detect previous host manager: %w", err) - } - - manager, err := newHostManagerFromType(wgIface, osManagerType) +func (s *ShutdownState) Cleanup() error { + manager, err := 
newHostManagerFromType(s.WgIface, s.ManagerType) if err != nil { return fmt.Errorf("create previous host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } -func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error { +// TODO: move file contents to state manager +func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { + dnsAddress, err := netip.ParseAddr(dnsAddressStr) + if err != nil { + return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) + } + dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { return fmt.Errorf("create dir %s: %w", dir, err) @@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType return fmt.Errorf("create %s: %w", sourcePath, err) } - managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress) - - if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec - return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err) - } - return nil -} - -func removeUncleanShutdownIndicator() error { - if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err) - } - if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err) + state := &ShutdownState{ + ManagerType: fileManager, + DNSAddress: dnsAddress, } + if err := 
stateManager.UpdateState(state); err != nil { + return fmt.Errorf("update state: %w", err) + } + return nil } diff --git a/client/internal/dns/unclean_shutdown_windows.go b/client/internal/dns/unclean_shutdown_windows.go index 41db46768..74e40cc11 100644 --- a/client/internal/dns/unclean_shutdown_windows.go +++ b/client/internal/dns/unclean_shutdown_windows.go @@ -1,75 +1,26 @@ package dns import ( - "errors" "fmt" - "io/fs" - "os" - "path/filepath" - - "github.com/sirupsen/logrus" ) -const ( - netbirdProgramDataLocation = "Netbird" - fileUncleanShutdownFile = "unclean_shutdown_dns.txt" -) +type ShutdownState struct { + Guid string +} -func CheckUncleanShutdown(string) error { - file := getUncleanShutdownFile() +func (s *ShutdownState) Name() string { + return "dns_state" +} - if _, err := os.Stat(file); err != nil { - if errors.Is(err, fs.ErrNotExist) { - // no file -> clean shutdown - return nil - } else { - return fmt.Errorf("state: %w", err) - } - } - - logrus.Warnf("detected unclean shutdown, file %s exists. 
Restoring unclean shutdown dns settings.", file) - - guid, err := os.ReadFile(file) - if err != nil { - return fmt.Errorf("read %s: %w", file, err) - } - - manager, err := newHostManagerWithGuid(string(guid)) +func (s *ShutdownState) Cleanup() error { + manager, err := newHostManagerWithGuid(s.Guid) if err != nil { return fmt.Errorf("create host manager: %w", err) } - if err := manager.restoreUncleanShutdownDNS(nil); err != nil { - return fmt.Errorf("restore unclean shutdown backup: %w", err) + if err := manager.restoreUncleanShutdownDNS(); err != nil { + return fmt.Errorf("restore unclean shutdown dns: %w", err) } return nil } - -func createUncleanShutdownIndicator(guid string) error { - file := getUncleanShutdownFile() - - dir := filepath.Dir(file) - if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { - return fmt.Errorf("create dir %s: %w", dir, err) - } - - if err := os.WriteFile(file, []byte(guid), 0600); err != nil { - return fmt.Errorf("create %s: %w", file, err) - } - - return nil -} - -func removeUncleanShutdownIndicator() error { - file := getUncleanShutdownFile() - - if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("remove %s: %w", file, err) - } - return nil -} - -func getUncleanShutdownFile() string { - return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile) -} diff --git a/client/internal/engine.go b/client/internal/engine.go index 459518de1..22dd1f584 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -23,18 +23,19 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" - - "github.com/netbirdio/netbird/client/iface" - 
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -166,6 +167,7 @@ type Engine struct { checks []*mgmProto.Checks relayManager *relayClient.Manager + stateManager *statemanager.Manager } // Peer is an instance of the Connection Peer @@ -213,7 +215,7 @@ func NewEngineWithProbes( probes *ProbeHolder, checks []*mgmProto.Checks, ) *Engine { - return &Engine{ + engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, signal: signalClient, @@ -232,6 +234,11 @@ func NewEngineWithProbes( probes: probes, checks: checks, } + if path := statemanager.GetDefaultStatePath(); path != "" { + engine.stateManager = statemanager.New(path) + } + + return engine } func (e *Engine) Stop() error { @@ -253,7 +260,7 @@ func (e *Engine) Stop() error { e.stopDNSServer() if e.routeManager != nil { - e.routeManager.Stop() + e.routeManager.Stop(e.stateManager) } err := e.removeAllPeers() @@ -275,6 +282,17 @@ func (e *Engine) Stop() error { e.close() log.Infof("stopped Netbird Engine") + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + if err := e.stateManager.Stop(ctx); err != nil { + return fmt.Errorf("failed to stop state manager: %w", err) + } + if err := e.stateManager.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + return nil } @@ -314,6 +332,8 @@ func (e *Engine) Start() error { } } + e.stateManager.Start() + initialRoutes, dnsServer, err := 
e.newDnsServer() if err != nil { e.close() @@ -322,7 +342,7 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() + beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { @@ -1219,10 +1239,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) return nil, dnsServer, nil default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager) if err != nil { return nil, nil, err } + return nil, dnsServer, nil } } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d7ddf7ae8..0a1c7dc56 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" @@ -31,14 +32,14 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) 
(map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error - Stop() + Stop(stateManager *statemanager.Manager) } // DefaultManager is the default instance of a route manager @@ -120,12 +121,12 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(nil); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } @@ -136,7 +137,7 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } @@ -154,7 +155,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop stops the manager watchers and clean firewall rules -func (m *DefaultManager) Stop() { +func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() if m.serverRouter != nil { m.serverRouter.cleanUp() @@ -172,7 +173,7 @@ func (m *DefaultManager) Stop() { } if !nbnet.CustomRoutingDisabled() { - if err := m.sysOps.CleanupRouting(); err != nil { + if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { log.Info("Routing cleanup complete") diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 044a996c7..e669bc44a 
100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -426,10 +426,10 @@ func TestManagerUpdateRoutes(t *testing.T) { ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) - _, _, err = routeManager.Init() + _, _, err = routeManager.Init(nil) require.NoError(t, err, "should init route manager") - defer routeManager.Stop() + defer routeManager.Stop(nil) if testCase.removeSrvRouter { routeManager.serverRouter = nil diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 908279c88..503185f03 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) @@ -17,10 +18,10 @@ type MockManager struct { UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelectionFunc func(haMap route.HAMap) GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func() + StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } @@ -65,8 +66,8 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop mock implementation of Stop from Manager interface -func (m *MockManager) Stop() { +func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { - m.StopFunc() + m.StopFunc(stateManager) } } diff --git a/client/internal/routemanager/systemops/state.go 
b/client/internal/routemanager/systemops/state.go new file mode 100644 index 000000000..269924677 --- /dev/null +++ b/client/internal/routemanager/systemops/state.go @@ -0,0 +1,81 @@ +package systemops + +import ( + "encoding/json" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +type RouteEntry struct { + Prefix netip.Prefix `json:"prefix"` + Nexthop Nexthop `json:"nexthop"` +} + +type ShutdownState struct { + Routes map[netip.Prefix]RouteEntry `json:"routes,omitempty"` + mu sync.RWMutex +} + +func NewShutdownState() *ShutdownState { + return &ShutdownState{ + Routes: make(map[netip.Prefix]RouteEntry), + } +} + +func (s *ShutdownState) Name() string { + return "route_state" +} + +func (s *ShutdownState) Cleanup() error { + sysops := NewSysOps(nil, nil) + var merr *multierror.Error + + s.mu.RLock() + defer s.mu.RUnlock() + + for _, route := range s.Routes { + if err := sysops.removeFromRouteTable(route.Prefix, route.Nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", route.Prefix, err)) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func (s *ShutdownState) UpdateRoute(prefix netip.Prefix, nexthop Nexthop) { + s.mu.Lock() + defer s.mu.Unlock() + + s.Routes[prefix] = RouteEntry{ + Prefix: prefix, + Nexthop: nexthop, + } +} + +func (s *ShutdownState) RemoveRoute(prefix netip.Prefix) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.Routes, prefix) +} + +// MarshalJSON ensures that empty routes are marshaled as null +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if len(s.Routes) == 0 { + return json.Marshal(nil) + } + + return json.Marshal(s.Routes) +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &s.Routes) +} diff --git a/client/internal/routemanager/systemops/systemops_android.go 
b/client/internal/routemanager/systemops/systemops_android.go index 5e97a4a5f..ca8aea3fb 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -9,14 +9,15 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { return nil } @@ -28,6 +29,10 @@ func (r *SysOps) RemoveVPNRoute(netip.Prefix, *net.Interface) error { return nil } +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 9258f4a4e..2b8a14ea2 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -30,7 +31,9 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) 
setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + stateManager.RegisterState(&ShutdownState{}) + initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) if err != nil && !errors.Is(err, vars.ErrRouteNotFound) { log.Errorf("Unable to get initial v4 default next hop: %v", err) @@ -53,9 +56,18 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn // These errors are not critical, but also we should not track and try to remove the routes either. return nexthop, refcounter.ErrIgnore } + + r.updateState(stateManager, prefix, nexthop) + return nexthop, err }, - r.removeFromRouteTable, + func(prefix netip.Prefix, nexthop Nexthop) error { + // remove from state even if we have trouble removing it from the route table + // it could be already gone + r.removeFromState(stateManager, prefix) + + return r.removeFromRouteTable(prefix, nexthop) + }, ) r.refCounter = refCounter @@ -63,7 +75,25 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn return r.setupHooks(initAddresses) } -func (r *SysOps) cleanupRefCounter() error { +func (r *SysOps) updateState(stateManager *statemanager.Manager, prefix netip.Prefix, nexthop Nexthop) { + state := getState(stateManager) + state.UpdateRoute(prefix, nexthop) + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("failed to update state: %v", err) + } +} + +func (r *SysOps) removeFromState(stateManager *statemanager.Manager, prefix netip.Prefix) { + state := getState(stateManager) + state.RemoveRoute(prefix) + + if err := stateManager.UpdateState(state); err != nil { + log.Errorf("Failed to update state: %v", err) + } +} + +func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { if r.refCounter == nil { return nil } @@ -76,6 +106,10 @@ func (r *SysOps) cleanupRefCounter() error { return fmt.Errorf("flush route manager: %w", err) } + if err := 
stateManager.DeleteState(&ShutdownState{}); err != nil { + log.Errorf("failed to delete state: %v", err) + } + return nil } @@ -506,3 +540,14 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } + +func getState(stateManager *statemanager.Manager) *ShutdownState { + var shutdownState *ShutdownState + if state := stateManager.GetState(shutdownState); state != nil { + shutdownState = state.(*ShutdownState) + } else { + shutdownState = NewShutdownState() + } + + return shutdownState +} diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index ce5b6b843..5b7b13f97 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -77,10 +77,10 @@ func TestAddRemoveRoutes(t *testing.T) { r := NewSysOps(wgInterface, nil) - _, _, err = r.SetupRouting(nil) + _, _, err = r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) @@ -403,10 +403,10 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil) + _, _, err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { - assert.NoError(t, r.CleanupRouting()) + assert.NoError(t, r.CleanupRouting(nil)) }) index, err := net.InterfaceByName(wgInterface.Name()) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index 7cfb2b298..bf06f3739 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -9,17 +9,18 @@ import ( log 
"github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil, nil, nil } -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(*statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() @@ -46,6 +47,18 @@ func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, _ *net.Interface) error { return nil } +func (r *SysOps) notify() { + prefixes := make([]netip.Prefix, 0, len(r.prefixes)) + for prefix := range r.prefixes { + prefixes = append(prefixes, prefix) + } + r.notifier.OnNewPrefixes(prefixes) +} + +func (r *SysOps) removeFromRouteTable(netip.Prefix, Nexthop) error { + return nil +} + func EnableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil @@ -54,11 +67,3 @@ func EnableIPForwarding() error { func IsAddrRouted(netip.Addr, []netip.Prefix) (bool, netip.Prefix) { return false, netip.Prefix{} } - -func (r *SysOps) notify() { - prefixes := make([]netip.Prefix, 0, len(r.prefixes)) - for prefix := range r.prefixes { - prefixes = append(prefixes, prefix) - } - r.notifier.OnNewPrefixes(prefixes) -} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 2d0c57826..0124fd95e 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -18,6 +18,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/routemanager/sysctl" "github.com/netbirdio/netbird/client/internal/routemanager/vars" + 
"github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -85,10 +86,10 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if isLegacy() { log.Infof("Using legacy routing setup") - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } if err = addRoutingTableName(); err != nil { @@ -104,7 +105,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb defer func() { if err != nil { - if cleanErr := r.CleanupRouting(); cleanErr != nil { + if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { log.Errorf("Error cleaning up routing: %v", cleanErr) } } @@ -116,7 +117,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb if errors.Is(err, syscall.EOPNOTSUPP) { log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") setIsLegacy(true) - return r.setupRefCounter(initAddresses) + return r.setupRefCounter(initAddresses, stateManager) } return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } @@ -128,9 +129,9 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP) (_ nbnet.AddHookFunc, _ nb // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. 
// The function uses error aggregation to report any errors encountered during the cleanup process. -func (r *SysOps) CleanupRouting() error { +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { if isLegacy() { - return r.cleanupRefCounter() + return r.cleanupRefCounter(stateManager) } var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index a2bbf35cf..0f8f2a341 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -13,15 +13,16 @@ import ( "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 3f756788e..b1732a080 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/internal/statemanager" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -130,12 +131,12 
@@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return r.setupRefCounter(initAddresses) +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting() error { - return r.cleanupRefCounter() +func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { + return r.cleanupRefCounter(stateManager) } func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go new file mode 100644 index 000000000..a5a14f807 --- /dev/null +++ b/client/internal/statemanager/manager.go @@ -0,0 +1,298 @@ +package statemanager + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io/fs" + "os" + "reflect" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +// State interface defines the methods that all state types must implement +type State interface { + Name() string + Cleanup() error +} + +// Manager handles the persistence and management of various states +type Manager struct { + mu sync.Mutex + cancel context.CancelFunc + done chan struct{} + + filePath string + // holds the states that are registered with the manager and that are to be persisted + states map[string]State + // holds the state names that have been updated and need to be persisted with the next save + dirty map[string]struct{} + // holds the type information for each registered state + stateTypes map[string]reflect.Type +} + +// New creates a new Manager instance +func New(filePath string) *Manager { + return &Manager{ + filePath: filePath, + states: make(map[string]State), + dirty: 
make(map[string]struct{}), + stateTypes: make(map[string]reflect.Type), + } +} + +// Start starts the state manager periodic save routine +func (m *Manager) Start() { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + var ctx context.Context + ctx, m.cancel = context.WithCancel(context.Background()) + m.done = make(chan struct{}) + + go m.periodicStateSave(ctx) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.cancel != nil { + m.cancel() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: + return nil + } + } + + return nil +} + +// RegisterState registers a state with the manager but doesn't attempt to persist it. +// Pass an uninitialized state to register it. +func (m *Manager) RegisterState(state State) { + if m == nil { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + name := state.Name() + m.states[name] = nil + m.stateTypes[name] = reflect.TypeOf(state).Elem() +} + +// GetState returns the state for the given type +func (m *Manager) GetState(state State) State { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + return m.states[state.Name()] +} + +// UpdateState updates the state in the manager and marks it as dirty for the next save. +// The state will be replaced with the new one. +func (m *Manager) UpdateState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), state) +} + +// DeleteState removes the state from the manager and marks it as dirty for the next save. +// Pass an uninitialized state to delete it. 
+func (m *Manager) DeleteState(state State) error { + if m == nil { + return nil + } + + return m.setState(state.Name(), nil) +} + +func (m *Manager) setState(name string, state State) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.states[name]; !exists { + return fmt.Errorf("state %s not registered", name) + } + + m.states[name] = state + m.dirty[name] = struct{}{} + + return nil +} + +func (m *Manager) periodicStateSave(ctx context.Context) { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + defer close(m.done) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := m.PersistState(ctx); err != nil { + log.Errorf("failed to persist state: %v", err) + } + } + } +} + +// PersistState persists the states that have been updated since the last save. +func (m *Manager) PersistState(ctx context.Context) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.dirty) == 0 { + return nil + } + + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + done := make(chan error, 1) + + go func() { + data, err := json.MarshalIndent(m.states, "", " ") + if err != nil { + done <- fmt.Errorf("marshal states: %w", err) + return + } + + // nolint:gosec + if err := os.WriteFile(m.filePath, data, 0640); err != nil { + done <- fmt.Errorf("write state file: %w", err) + return + } + + done <- nil + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + if err != nil { + return err + } + } + + log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + + clear(m.dirty) + + return nil +} + +// loadState loads the existing state from the state file +func (m *Manager) loadState() error { + data, err := os.ReadFile(m.filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + log.Debug("state file does not exist") + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + var rawStates map[string]json.RawMessage + if err 
:= json.Unmarshal(data, &rawStates); err != nil { + log.Warn("State file appears to be corrupted, attempting to delete it") + if err := os.Remove(m.filePath); err != nil { + log.Errorf("Failed to delete corrupted state file: %v", err) + } else { + log.Info("State file deleted") + } + return fmt.Errorf("unmarshal states: %w", err) + } + + var merr *multierror.Error + + for name, rawState := range rawStates { + stateType, ok := m.stateTypes[name] + if !ok { + merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) + continue + } + + if string(rawState) == "null" { + continue + } + + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) + continue + } + + m.states[name] = statePtr + log.Debugf("loaded state: %s", name) + } + + return nberrors.FormatErrorOrNil(merr) +} + +// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. +// If the cleanup is successful, the state is marked for deletion. 
+func (m *Manager) PerformCleanup() error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.loadState(); err != nil { + log.Warnf("Failed to load state during cleanup: %v", err) + } + + var merr *multierror.Error + for name, state := range m.states { + if state == nil { + // If no state was found in the state file, we don't mark the state dirty nor return an error + continue + } + + log.Infof("client was not shut down properly, cleaning up %s", name) + if err := state.Cleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + } else { + // mark for deletion on cleanup success + m.states[name] = nil + m.dirty[name] = struct{}{} + } + } + + return nberrors.FormatErrorOrNil(merr) +} diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go new file mode 100644 index 000000000..64c5316d8 --- /dev/null +++ b/client/internal/statemanager/path.go @@ -0,0 +1,35 @@ +package statemanager + +import ( + "os" + "path/filepath" + "runtime" + + "github.com/sirupsen/logrus" +) + +// GetDefaultStatePath returns the path to the state file based on the operating system +// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +func GetDefaultStatePath() string { + var path string + + switch runtime.GOOS { + case "windows": + path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + case "darwin", "linux": + path = "/var/lib/netbird/state.json" + case "freebsd", "openbsd", "netbsd", "dragonfly": + path = "/var/db/netbird/state.json" + // ios/android don't need state + default: + return "" + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + logrus.Errorf("Error creating directory %s: %v. 
Continuing without state support.", dir, err) + return "" + } + + return path +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index dc13706bf..9d65bdbe0 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -138,12 +138,12 @@ func (c *Client) Stop() { c.ctxCancel() } -// ÏSetTraceLogLevel configure the logger to trace level +// SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) } -// getStatusDetails return with the list of the PeerInfos +// GetStatusDetails return with the list of the PeerInfos func (c *Client) GetStatusDetails() *StatusDetails { fullStatus := c.recorder.GetFullStatus() diff --git a/client/server/server.go b/client/server/server.go index 0a4c18131..342f61b88 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" "google.golang.org/protobuf/types/known/durationpb" @@ -20,7 +21,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/auth" + "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/internal" @@ -39,6 +44,8 @@ const ( defaultMaxRetryInterval = 60 * time.Minute defaultMaxRetryTime = 14 * 24 * time.Hour defaultRetryMultiplier = 1.7 + + errRestoreResidualState = "failed to restore residual state: %v" ) // Server for service control. 
@@ -95,6 +102,10 @@ func (s *Server) Start() error { defer s.mutex.Unlock() state := internal.CtxGetState(s.rootCtx) + if err := restoreResidualState(s.rootCtx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + // if current state contains any error, return it // in all other cases we can continue execution only if status is idle and up command was // not in the progress or already successfully established connection. @@ -292,6 +303,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.actCancel = cancel s.mutex.Unlock() + if err := restoreResidualState(ctx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + state := internal.CtxGetState(ctx) defer func() { status, err := state.Status() @@ -549,6 +564,10 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.mutex.Lock() defer s.mutex.Unlock() + if err := restoreResidualState(callerCtx); err != nil { + log.Warnf(errRestoreResidualState, err) + } + state := internal.CtxGetState(s.rootCtx) // if current state contains any error, return it @@ -829,3 +848,31 @@ func sendTerminalNotification() error { return wallCmd.Wait() } + +// restoreResidualState checks if the client was not shut down in a clean way and restores residual state if required. +// Otherwise, we might not be able to connect to the management server to retrieve new config. 
+func restoreResidualState(ctx context.Context) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + mgr := statemanager.New(path) + + var merr *multierror.Error + + // register the states we are interested in restoring + // this will also allow each subsystem to record its own state + mgr.RegisterState(&dns.ShutdownState{}) + mgr.RegisterState(&systemops.ShutdownState{}) + + if err := mgr.PerformCleanup(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) + } + + if err := mgr.PersistState(ctx); err != nil { + merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) + } + + return nberrors.FormatErrorOrNil(merr) +}