diff --git a/client/firewall/create.go b/client/firewall/create.go index 7b265e1d1..6f25bae46 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -15,7 +15,7 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index aa2f0d4d1..c9a2d6531 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -34,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers @@ -51,7 +51,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { +func createNativeFirewall(iface IFaceMapper, stateManager statemanager.Manager, routes bool) (firewall.Manager, error) { fm, err := createFW(iface) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 183417327..6d4c87532 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -36,7 +36,7 @@ type aclManager struct { optionalEntries map[string][]entry ipsetStore *ipsetStore - stateManager *statemanager.Manager + stateManager statemanager.Manager } func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { @@ -55,7 +55,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*acl return m, nil } -func (m *aclManager) init(stateManager *statemanager.Manager) error { +func (m *aclManager) init(stateManager statemanager.Manager) error { m.stateManager = stateManager m.seedInitialEntries() diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 652ab1b3e..c634db7dc 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -60,7 +60,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { return m, nil } -func (m *Manager) Init(stateManager *statemanager.Manager) error { +func (m *Manager) Init(stateManager statemanager.Manager) error { state := &ShutdownState{ InterfaceState: &InterfaceState{ NameStr: m.wgIface.Name(), @@ -167,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Close(stateManager *statemanager.Manager) error { +func (m *Manager) Close(stateManager statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index eae9f7e25..56429a117 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -76,7 +76,7 @@ type router struct { wgIface iFaceMapper legacyManagement bool - stateManager *statemanager.Manager + stateManager statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState } @@ -104,7 +104,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, return r, nil } -func (r *router) init(stateManager *statemanager.Manager) error { +func (r *router) init(stateManager statemanager.Manager) error { r.stateManager = stateManager if err := r.cleanUpDefaultForwardRules(); err != nil { diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 1d71051ef..210ae7bc6 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -55,7 +55,7 @@ const ( // It declares methods which handle actions required by the // Netbird client for ACL and routing functionality type Manager interface { - Init(stateManager *statemanager.Manager) error + Init(stateManager statemanager.Manager) error // AllowNetbird allows netbird interface traffic AllowNetbird() error @@ -103,7 +103,7 @@ type Manager interface { SetLegacyManagement(legacy bool) error // Close closes the firewall manager - Close(stateManager *statemanager.Manager) error + Close(stateManager statemanager.Manager) error // Flush the changes to firewall controller Flush() error diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a5809471c..36e4aed2a 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -67,7 +67,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { } // Init nftables firewall manager -func (m *Manager) Init(stateManager *statemanager.Manager) error { +func (m *Manager) Init(stateManager statemanager.Manager) error { workTable, err := m.createWorkTable() if err != nil { return fmt.Errorf("create work table: %w", err) @@ -243,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { } // Reset firewall to the default state -func (m *Manager) Close(stateManager *statemanager.Manager) error { +func (m *Manager) Close(stateManager statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 5fe698aa9..b24b27bd2 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -13,7 +13,7 @@ import ( ) // Reset firewall to the default state -func (m *Manager) Close(stateManager *statemanager.Manager) error { +func (m *Manager) Close(stateManager statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index f63792fec..d601f3af8 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -23,7 +23,7 @@ const ( ) // Reset firewall to the default state -func (m *Manager) Close(*statemanager.Manager) error { +func (m *Manager) Close(statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 92da1b240..75ebd5e97 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -322,7 +322,7 @@ func (m *Manager) initForwarder() error { return nil } -func (m *Manager) Init(*statemanager.Manager) error { +func (m *Manager) Init(statemanager.Manager) error { return nil } diff --git a/client/internal/dns/file_repair_unix.go b/client/internal/dns/file_repair_unix.go index 9a9218fa1..6e77e7ac3 100644 --- a/client/internal/dns/file_repair_unix.go +++ b/client/internal/dns/file_repair_unix.go @@ -22,7 +22,7 @@ var ( } ) -type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error +type repairConfFn func([]string, string, *resolvConf, statemanager.Manager) error type repair struct { operationFile string @@ -42,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair { } } -func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { +func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager statemanager.Manager) { if f.inotify != nil { return } diff --git a/client/internal/dns/file_repair_unix_test.go b/client/internal/dns/file_repair_unix_test.go index e948557b6..dc1f2541f 100644 --- a/client/internal/dns/file_repair_unix_test.go +++ b/client/internal/dns/file_repair_unix_test.go @@ -105,7 +105,7 @@ nameserver 8.8.8.8`, var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { + updateFn := func([]string, string, *resolvConf, statemanager.Manager) error { changed = true cancel() return nil @@ -152,7 +152,7 @@ searchdomain netbird.cloud something` var changed bool ctx, cancel := context.WithTimeout(context.Background(), time.Second) - updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { + updateFn := func([]string, string, *resolvConf, statemanager.Manager) error { changed = true cancel() return nil diff --git a/client/internal/dns/file_unix.go b/client/internal/dns/file_unix.go index 1f4ddb67c..33f01d431 100644 --- a/client/internal/dns/file_unix.go +++ b/client/internal/dns/file_unix.go @@ -48,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool { return false } -func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { backupFileExist := f.isBackupFileExist() if !config.RouteAll { if backupFileExist { @@ -86,7 +86,7 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st return nil } -func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { +func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager statemanager.Manager) error { searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) nameServers := generateNsList(nbNameserverIP, cfg) diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index 25e9ff7e5..0a92fcb6b 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -17,7 +17,7 @@ const ( ) type hostManager interface { - applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error + applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error restoreHostDNS() error supportCustomPort() bool string() string @@ -43,14 +43,14 @@ type DomainConfig struct { } type mockHostConfigurator struct { - applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error + applyDNSConfigFunc func(config HostDNSConfig, stateManager statemanager.Manager) error restoreHostDNSFunc func() error supportCustomPortFunc func() bool restoreUncleanShutdownDNSFunc func(*netip.Addr) error stringFunc func() string } -func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { if m.applyDNSConfigFunc != nil { return m.applyDNSConfigFunc(config, stateManager) } @@ -80,7 +80,7 @@ func (m *mockHostConfigurator) string() string { func newNoopHostMocker() hostManager { return &mockHostConfigurator{ - applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, + applyDNSConfigFunc: func(config HostDNSConfig, stateManager statemanager.Manager) error { return nil }, restoreHostDNSFunc: func() error { return nil }, supportCustomPortFunc: func() bool { return true }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, @@ -122,7 +122,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD type noopHostConfigurator struct{} -func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { +func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, statemanager.Manager) error { return nil } diff --git a/client/internal/dns/host_android.go b/client/internal/dns/host_android.go index dfa3e5712..24cb57246 100644 --- a/client/internal/dns/host_android.go +++ b/client/internal/dns/host_android.go @@ -11,7 +11,7 @@ func newHostManager() (*androidHostManager, error) { return &androidHostManager{}, nil } -func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { +func (a androidHostManager) applyDNSConfig(HostDNSConfig, statemanager.Manager) error { return nil } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index f727f68b5..76acd9497 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -5,6 +5,7 @@ package dns import ( "bufio" "bytes" + "errors" "fmt" "io" "net" @@ -49,7 +50,7 @@ func (s *systemConfigurator) supportCustomPort() bool { return true } -func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { var err error if err := stateManager.UpdateState(&ShutdownState{}); err != nil { @@ -200,8 +201,12 @@ func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { primaryServiceKey, _, err := s.getPrimaryService() if err != nil || primaryServiceKey == "" { + if err == nil { + err = errors.New("primary service key not found") + } return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err) } + dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey) line := buildCommandLine("show", dnsServiceKey, "") stdinCommands := wrapCommand(line) @@ -379,7 +384,7 @@ func buildWriteStateOperation(operation, state, commands string) string { return fmt.Sprintf("d.init\n%s %s\n%s\nset %s\n", operation, state, commands, state) } -func runSystemConfigCommand(command string) ([]byte, error) { +var runSystemConfigCommand = func(command string) ([]byte, error) { cmd := exec.Command(scutilPath) cmd.Stdin = strings.NewReader(command) out, err := cmd.Output() diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..b3654f21f --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,210 @@ +package dns + +import ( + "errors" + "os/exec" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/mock/gomock" + + "github.com/netbirdio/netbird/client/internal/statemanager/mocks" +) + +// MockCommander to mock exec.Command +type MockCommander struct { + mock.Mock +} + +func (m *MockCommander) Command(name string, arg ...string) *exec.Cmd { + args := m.Called(name, arg) + return args.Get(0).(*exec.Cmd) +} + +func TestNewHostManager(t *testing.T) { + tests := []struct { + name string + wantErr bool + }{ + { + name: "successful creation", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newHostManager() + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.NotNil(t, got) + assert.NotNil(t, got.createdKeys) + }) + } +} + +func TestApplyDNSConfig(t *testing.T) { + type mockSetup struct { + stateManagerError error + commandOutput []byte + commandError error + } + + tests := []struct { + name string + config HostDNSConfig + mockSetup mockSetup + wantErr bool + }{ + { + name: "successful apply with search domains", + config: HostDNSConfig{ + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: false}, + {Domain: "test.com", MatchOnly: true}, + }, + ServerIP: "1.1.1.1", + ServerPort: 53, + }, + mockSetup: mockSetup{ + stateManagerError: nil, + commandOutput: []byte(` + PrimaryService : ABC123 + Router : 192.168.1.1 + DomainName : example.com + SearchDomains : { + 0 : test.com + } + ServerAddresses : { + 0 : 1.1.1.1 + } + `), + commandError: nil, + }, + wantErr: false, + }, + { + name: "state manager error", + config: HostDNSConfig{ + ServerIP: "1.1.1.1", + }, + mockSetup: mockSetup{ + stateManagerError: errors.New("state error"), + }, + wantErr: false, // Function does not return an error, it only logs it. + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup mocks + s := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() // Ensures all expectations are met + + mockState := mocks.NewMockManager(ctrl) + mockCmd := new(MockCommander) + + // Mock UpdateState + mockState.EXPECT().UpdateState(gomock.Any()).Return(tt.mockSetup.stateManagerError).AnyTimes() + + // Mock all expected command executions + // mockCmd.On("Command", dscacheutilPath, "-flushcache").Return(&exec.Cmd{}).Once() + // mockCmd.On("Command", "killall", "-HUP", "mDNSResponder").Return(&exec.Cmd{}).Once() + // mockCmd.On("Command", scutilPath).Return(&exec.Cmd{}).Once() // For runSystemConfigCommand + + // Mock `runSystemConfigCommand` + originalRunCommand := runSystemConfigCommand + runSystemConfigCommand = func(command string) ([]byte, error) { + return tt.mockSetup.commandOutput, tt.mockSetup.commandError + } + defer func() { runSystemConfigCommand = originalRunCommand }() + + err := s.applyDNSConfig(tt.config, mockState) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + mockCmd.AssertExpectations(t) // Ensure Command() is called + }) + } +} + +func TestGetSystemDNSSettings(t *testing.T) { + tests := []struct { + name string + commandOutput []byte + commandError error + wantSettings SystemDNSSettings + wantErr bool + }{ + { + name: "successful retrieval", + commandOutput: []byte(` +PrimaryService : ABC123 +Router : 192.168.1.1 +--- +DomainName : example.com +SearchDomains : { + 0 : test.com +} +ServerAddresses : { + 0 : 1.1.1.1 +} +`), + wantSettings: SystemDNSSettings{ + Domains: []string{"example.com", "test.com"}, + ServerIP: "1.1.1.1", + ServerPort: 53, + }, + wantErr: false, + }, + { + name: "command error", + commandError: errors.New("command failed"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + originalRunCommand := runSystemConfigCommand + runSystemConfigCommand = func(command string) ([]byte, error) { + return tt.commandOutput, tt.commandError + } + defer func() { runSystemConfigCommand = originalRunCommand }() + + got, err := s.getSystemDNSSettings() + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantSettings, got) + }) + } +} + +func TestSupportCustomPort(t *testing.T) { + s := &systemConfigurator{} + assert.True(t, s.supportCustomPort()) +} + +func TestString(t *testing.T) { + s := &systemConfigurator{} + assert.Equal(t, "scutil", s.string()) +} diff --git a/client/internal/dns/host_ios.go b/client/internal/dns/host_ios.go index 1c0ac63e9..b7a92c05f 100644 --- a/client/internal/dns/host_ios.go +++ b/client/internal/dns/host_ios.go @@ -20,7 +20,7 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { }, nil } -func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { +func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ statemanager.Manager) error { jsonData, err := json.Marshal(config) if err != nil { return fmt.Errorf("marshal: %w", err) diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index dceb24420..1adadd609 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -74,7 +74,7 @@ func (r *registryConfigurator) supportCustomPort() bool { return false } -func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { if config.RouteAll { if err := r.addDNSSetupForAll(config.ServerIP); err != nil { return fmt.Errorf("add dns setup: %w", err) diff --git a/client/internal/dns/network_manager_unix.go b/client/internal/dns/network_manager_unix.go index 10b4e6a6e..087994489 100644 --- a/client/internal/dns/network_manager_unix.go +++ b/client/internal/dns/network_manager_unix.go @@ -103,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool { return false } -func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { connSettings, configVersion, err := n.getAppliedConnectionSettings() if err != nil { return fmt.Errorf("retrieving the applied connection settings, error: %w", err) diff --git a/client/internal/dns/resolvconf_unix.go b/client/internal/dns/resolvconf_unix.go index 54c4c75bf..cea2708cd 100644 --- a/client/internal/dns/resolvconf_unix.go +++ b/client/internal/dns/resolvconf_unix.go @@ -84,7 +84,7 @@ func (r *resolvconf) supportCustomPort() bool { return false } -func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { var err error if !config.RouteAll { err = r.restoreHostDNS() diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index bc87012f2..fa688d254 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -75,7 +75,7 @@ type DefaultServer struct { iosDnsManager IosDnsManager statusRecorder *peer.Status - stateManager *statemanager.Manager + stateManager statemanager.Manager } type handlerWithStop interface { @@ -99,7 +99,7 @@ func NewDefaultServer( wgInterface WGIface, customAddress string, statusRecorder *peer.Status, - stateManager *statemanager.Manager, + stateManager statemanager.Manager, disableSys bool, ) (*DefaultServer, error) { var addrPort *netip.AddrPort @@ -161,7 +161,7 @@ func newDefaultServer( wgInterface WGIface, dnsService service, statusRecorder *peer.Status, - stateManager *statemanager.Manager, + stateManager statemanager.Manager, disableSys bool, ) *DefaultServer { ctx, stop := context.WithCancel(ctx) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index c7eeb7870..677296d53 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -647,7 +647,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { } var domainsUpdate string - hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { + hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager statemanager.Manager) error { domains := []string{} for _, item := range config.Domains { if item.Disabled { diff --git a/client/internal/dns/systemd_linux.go b/client/internal/dns/systemd_linux.go index a87cc73e5..de8ef9c66 100644 --- a/client/internal/dns/systemd_linux.go +++ b/client/internal/dns/systemd_linux.go @@ -87,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool { return true } -func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { +func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager statemanager.Manager) error { parsedIP, err := netip.ParseAddr(config.ServerIP) if err != nil { return fmt.Errorf("unable to parse ip address, error: %w", err) diff --git a/client/internal/dns/unclean_shutdown_unix.go b/client/internal/dns/unclean_shutdown_unix.go index fcf60c694..0bafec1c4 100644 --- a/client/internal/dns/unclean_shutdown_unix.go +++ b/client/internal/dns/unclean_shutdown_unix.go @@ -35,7 +35,7 @@ func (s *ShutdownState) Cleanup() error { } // TODO: move file contents to state manager -func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error { +func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager statemanager.Manager) error { dnsAddress, err := netip.ParseAddr(dnsAddressStr) if err != nil { return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err) diff --git a/client/internal/engine.go b/client/internal/engine.go index babea2131..39d1c0692 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -184,7 +184,7 @@ type Engine struct { checks []*mgmProto.Checks relayManager *relayClient.Manager - stateManager *statemanager.Manager + stateManager statemanager.Manager srWatcher *guard.SRWatcher // Network map persistence diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index ae0d1d220..11fcaa447 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -45,7 +45,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error - Stop(stateManager *statemanager.Manager) + Stop(stateManager statemanager.Manager) } type ManagerConfig struct { @@ -56,7 +56,7 @@ type ManagerConfig struct { StatusRecorder *peer.Status RelayManager *relayClient.Manager InitialRoutes []*route.Route - StateManager *statemanager.Manager + StateManager statemanager.Manager DNSServer dns.Server PeerStore *peerstore.Store DisableClientRoutes bool @@ -80,7 +80,7 @@ type DefaultManager struct { routeRefCounter *refcounter.RouteRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration - stateManager *statemanager.Manager + stateManager statemanager.Manager // clientRoutes is the most recent list of clientRoutes received from the Management Service clientRoutes route.HAMap dnsServer dns.Server @@ -234,7 +234,7 @@ func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop stops the manager watchers and clean firewall rules -func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { +func (m *DefaultManager) Stop(stateManager statemanager.Manager) { m.stop() if m.serverRouter != nil { m.serverRouter.cleanUp() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 64fdffceb..ab821b7be 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -19,7 +19,7 @@ type MockManager struct { GetRouteSelectorFunc func() *routeselector.RouteSelector GetClientRoutesFunc func() route.HAMap GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route - StopFunc func(manager *statemanager.Manager) + StopFunc func(manager statemanager.Manager) } func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { @@ -83,7 +83,7 @@ func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { } // Stop mock implementation of Stop from Manager interface -func (m *MockManager) Stop(stateManager *statemanager.Manager) { +func (m *MockManager) Stop(stateManager statemanager.Manager) { if m.StopFunc != nil { m.StopFunc(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index ca8aea3fb..648730f1c 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -13,11 +13,11 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return nil, nil, nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(statemanager.Manager) error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index eaef01815..f62c70036 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -32,7 +32,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { stateManager.RegisterState(&ShutdownState{}) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -80,13 +80,13 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana } // updateState updates state on every change so it will be persisted regularly -func (r *SysOps) updateState(stateManager *statemanager.Manager) { +func (r *SysOps) updateState(stateManager statemanager.Manager) { if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } -func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error { +func (r *SysOps) cleanupRefCounter(stateManager statemanager.Manager) error { if r.refCounter == nil { return nil } @@ -337,7 +337,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index bf06f3739..4b7c7c19a 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -13,14 +13,14 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) return nil, nil, nil } -func (r *SysOps) CleanupRouting(*statemanager.Manager) error { +func (r *SysOps) CleanupRouting(statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index d724cb1a7..1feae3d7f 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -72,7 +72,7 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { if !nbnet.AdvancedRouting() { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) @@ -110,7 +110,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager statemanager.Manager) error { if !nbnet.AdvancedRouting() { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 0f8f2a341..d356496fc 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -17,11 +17,11 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager statemanager.Manager) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index ad325e123..2ef7419e0 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -131,11 +131,11 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return r.setupRefCounter(initAddresses, stateManager) } -func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error { +func (r *SysOps) CleanupRouting(stateManager statemanager.Manager) error { return r.cleanupRefCounter(stateManager) } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..69fd33a05 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -49,8 +49,25 @@ func (r *RawState) MarshalJSON() ([]byte, error) { return r.data, nil } +// Manager is the interface that exposes the persistence and state management methods. +type Manager interface { + Start() + Stop(ctx context.Context) error + RegisterState(state State) + GetState(state State) State + UpdateState(state State) error + DeleteState(state State) error + DeleteStateByName(stateName string) error + DeleteAllStates() (int, error) + PersistState(ctx context.Context) error + LoadState(state State) error + CleanupStateByName(name string) error + PerformCleanup() error + GetSavedStateNames() ([]string, error) +} + // Manager handles the persistence and management of various states -type Manager struct { +type managerImpl struct { mu sync.Mutex cancel context.CancelFunc done chan struct{} @@ -65,8 +82,8 @@ type Manager struct { } // New creates a new Manager instance -func New(filePath string) *Manager { - return &Manager{ +func New(filePath string) Manager { + return &managerImpl{ filePath: filePath, states: make(map[string]State), dirty: make(map[string]struct{}), @@ -75,7 +92,7 @@ func New(filePath string) *Manager { } // Start starts the state manager periodic save routine -func (m *Manager) Start() { +func (m *managerImpl) Start() { if m == nil { return } @@ -90,7 +107,7 @@ func (m *Manager) Start() { go m.periodicStateSave(ctx) } -func (m *Manager) Stop(ctx context.Context) error { +func (m *managerImpl) Stop(ctx context.Context) error { if m == nil { return nil } @@ -114,7 +131,7 @@ func (m *Manager) Stop(ctx context.Context) error { // RegisterState registers a state with the manager but doesn't attempt to persist it. // Pass an uninitialized state to register it. -func (m *Manager) RegisterState(state State) { +func (m *managerImpl) RegisterState(state State) { if m == nil { return } @@ -128,7 +145,7 @@ func (m *Manager) RegisterState(state State) { } // GetState returns the state for the given type -func (m *Manager) GetState(state State) State { +func (m *managerImpl) GetState(state State) State { if m == nil { return nil } @@ -141,7 +158,7 @@ func (m *Manager) GetState(state State) State { // UpdateState updates the state in the manager and marks it as dirty for the next save. // The state will be replaced with the new one. -func (m *Manager) UpdateState(state State) error { +func (m *managerImpl) UpdateState(state State) error { if m == nil { return nil } @@ -151,7 +168,7 @@ func (m *Manager) UpdateState(state State) error { // DeleteState removes the state from the manager and marks it as dirty for the next save. // Pass an uninitialized state to delete it. -func (m *Manager) DeleteState(state State) error { +func (m *managerImpl) DeleteState(state State) error { if m == nil { return nil } @@ -159,7 +176,7 @@ func (m *Manager) DeleteState(state State) error { return m.setState(state.Name(), nil) } -func (m *Manager) setState(name string, state State) error { +func (m *managerImpl) setState(name string, state State) error { m.mu.Lock() defer m.mu.Unlock() @@ -175,7 +192,7 @@ func (m *Manager) setState(name string, state State) error { // DeleteStateByName handles deletion of states without cleanup. // It doesn't require the state to be registered. -func (m *Manager) DeleteStateByName(stateName string) error { +func (m *managerImpl) DeleteStateByName(stateName string) error { if m == nil { return nil } @@ -203,7 +220,7 @@ func (m *Manager) DeleteStateByName(stateName string) error { } // DeleteAllStates removes all states. -func (m *Manager) DeleteAllStates() (int, error) { +func (m *managerImpl) DeleteAllStates() (int, error) { if m == nil { return 0, nil } @@ -230,7 +247,7 @@ func (m *Manager) DeleteAllStates() (int, error) { return count, nil } -func (m *Manager) periodicStateSave(ctx context.Context) { +func (m *managerImpl) periodicStateSave(ctx context.Context) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() defer close(m.done) @@ -248,7 +265,7 @@ func (m *Manager) periodicStateSave(ctx context.Context) { } // PersistState persists the states that have been updated since the last save. -func (m *Manager) PersistState(ctx context.Context) error { +func (m *managerImpl) PersistState(ctx context.Context) error { if m == nil { return nil } @@ -291,7 +308,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } // loadStateFile reads and unmarshals the state file into a map of raw JSON messages -func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) { +func (m *managerImpl) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { @@ -311,7 +328,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, } // handleCorruptedState creates a backup of a corrupted state file by moving it -func (m *Manager) handleCorruptedState(deleteCorrupt bool) { +func (m *managerImpl) handleCorruptedState(deleteCorrupt bool) { if !deleteCorrupt { return } @@ -327,7 +344,7 @@ func (m *Manager) handleCorruptedState(deleteCorrupt bool) { } // loadSingleRawState unmarshals a raw state into a concrete state object -func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { +func (m *managerImpl) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { stateType, ok := m.stateTypes[name] if !ok { return nil, fmt.Errorf(errStateNotRegistered, name) @@ -346,7 +363,7 @@ func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (Sta } // LoadState loads a specific state from the state file -func (m *Manager) LoadState(state State) error { +func (m *managerImpl) LoadState(state State) error { if m == nil { return nil } @@ -383,7 +400,7 @@ func (m *Manager) LoadState(state State) error { // cleanupSingleState handles the cleanup of a specific state and returns any error. // The caller must hold the mutex. -func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error { +func (m *managerImpl) cleanupSingleState(name string, rawState json.RawMessage) error { // For unregistered states, preserve the raw JSON if _, registered := m.stateTypes[name]; !registered { m.states[name] = &RawState{data: rawState} @@ -424,7 +441,7 @@ func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) erro // CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState. // Returns an error if the state doesn't exist, isn't registered, or cleanup fails. -func (m *Manager) CleanupStateByName(name string) error { +func (m *managerImpl) CleanupStateByName(name string) error { if m == nil { return nil } @@ -461,7 +478,7 @@ func (m *Manager) CleanupStateByName(name string) error { // PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. // Unregistered states are preserved in their original state. -func (m *Manager) PerformCleanup() error { +func (m *managerImpl) PerformCleanup() error { if m == nil { return nil } @@ -491,7 +508,7 @@ func (m *Manager) PerformCleanup() error { } // GetSavedStateNames returns all state names that are currently saved in the state file. -func (m *Manager) GetSavedStateNames() ([]string, error) { +func (m *managerImpl) GetSavedStateNames() ([]string, error) { if m == nil { return nil, nil } diff --git a/client/internal/statemanager/mocks/manager.go b/client/internal/statemanager/mocks/manager.go new file mode 100644 index 000000000..226c79a0c --- /dev/null +++ b/client/internal/statemanager/mocks/manager.go @@ -0,0 +1,312 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client/internal/statemanager/manager.go +// +// Generated by this command: +// +// mockgen -source client/internal/statemanager/manager.go -destination lient/internal/statemanager/manager_mock.go Manager +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + statemanager "github.com/netbirdio/netbird/client/internal/statemanager" + gomock "go.uber.org/mock/gomock" +) + +// MockState is a mock of State interface. +type MockState struct { + ctrl *gomock.Controller + recorder *MockStateMockRecorder + isgomock struct{} +} + +// MockStateMockRecorder is the mock recorder for MockState. +type MockStateMockRecorder struct { + mock *MockState +} + +// NewMockState creates a new mock instance. +func NewMockState(ctrl *gomock.Controller) *MockState { + mock := &MockState{ctrl: ctrl} + mock.recorder = &MockStateMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockState) EXPECT() *MockStateMockRecorder { + return m.recorder +} + +// Name mocks base method. +func (m *MockState) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockStateMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockState)(nil).Name)) +} + +// MockCleanableState is a mock of CleanableState interface. +type MockCleanableState struct { + ctrl *gomock.Controller + recorder *MockCleanableStateMockRecorder + isgomock struct{} +} + +// MockCleanableStateMockRecorder is the mock recorder for MockCleanableState. +type MockCleanableStateMockRecorder struct { + mock *MockCleanableState +} + +// NewMockCleanableState creates a new mock instance. +func NewMockCleanableState(ctrl *gomock.Controller) *MockCleanableState { + mock := &MockCleanableState{ctrl: ctrl} + mock.recorder = &MockCleanableStateMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCleanableState) EXPECT() *MockCleanableStateMockRecorder { + return m.recorder +} + +// Cleanup mocks base method. +func (m *MockCleanableState) Cleanup() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cleanup") + ret0, _ := ret[0].(error) + return ret0 +} + +// Cleanup indicates an expected call of Cleanup. +func (mr *MockCleanableStateMockRecorder) Cleanup() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cleanup", reflect.TypeOf((*MockCleanableState)(nil).Cleanup)) +} + +// Name mocks base method. +func (m *MockCleanableState) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockCleanableStateMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCleanableState)(nil).Name)) +} + +// MockManager is a mock of Manager interface. +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder + isgomock struct{} +} + +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// CleanupStateByName mocks base method. +func (m *MockManager) CleanupStateByName(name string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupStateByName", name) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupStateByName indicates an expected call of CleanupStateByName. +func (mr *MockManagerMockRecorder) CleanupStateByName(name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStateByName", reflect.TypeOf((*MockManager)(nil).CleanupStateByName), name) +} + +// DeleteAllStates mocks base method. +func (m *MockManager) DeleteAllStates() (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteAllStates") + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteAllStates indicates an expected call of DeleteAllStates. +func (mr *MockManagerMockRecorder) DeleteAllStates() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllStates", reflect.TypeOf((*MockManager)(nil).DeleteAllStates)) +} + +// DeleteState mocks base method. +func (m *MockManager) DeleteState(state statemanager.State) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteState", state) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteState indicates an expected call of DeleteState. +func (mr *MockManagerMockRecorder) DeleteState(state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteState", reflect.TypeOf((*MockManager)(nil).DeleteState), state) +} + +// DeleteStateByName mocks base method. +func (m *MockManager) DeleteStateByName(stateName string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteStateByName", stateName) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteStateByName indicates an expected call of DeleteStateByName. +func (mr *MockManagerMockRecorder) DeleteStateByName(stateName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteStateByName", reflect.TypeOf((*MockManager)(nil).DeleteStateByName), stateName) +} + +// GetSavedStateNames mocks base method. +func (m *MockManager) GetSavedStateNames() ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSavedStateNames") + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSavedStateNames indicates an expected call of GetSavedStateNames. +func (mr *MockManagerMockRecorder) GetSavedStateNames() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedStateNames", reflect.TypeOf((*MockManager)(nil).GetSavedStateNames)) +} + +// GetState mocks base method. +func (m *MockManager) GetState(state statemanager.State) statemanager.State { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetState", state) + ret0, _ := ret[0].(statemanager.State) + return ret0 +} + +// GetState indicates an expected call of GetState. +func (mr *MockManagerMockRecorder) GetState(state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetState", reflect.TypeOf((*MockManager)(nil).GetState), state) +} + +// LoadState mocks base method. +func (m *MockManager) LoadState(state statemanager.State) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LoadState", state) + ret0, _ := ret[0].(error) + return ret0 +} + +// LoadState indicates an expected call of LoadState. +func (mr *MockManagerMockRecorder) LoadState(state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoadState", reflect.TypeOf((*MockManager)(nil).LoadState), state) +} + +// PerformCleanup mocks base method. +func (m *MockManager) PerformCleanup() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PerformCleanup") + ret0, _ := ret[0].(error) + return ret0 +} + +// PerformCleanup indicates an expected call of PerformCleanup. +func (mr *MockManagerMockRecorder) PerformCleanup() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformCleanup", reflect.TypeOf((*MockManager)(nil).PerformCleanup)) +} + +// PersistState mocks base method. +func (m *MockManager) PersistState(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PersistState", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// PersistState indicates an expected call of PersistState. +func (mr *MockManagerMockRecorder) PersistState(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PersistState", reflect.TypeOf((*MockManager)(nil).PersistState), ctx) +} + +// RegisterState mocks base method. +func (m *MockManager) RegisterState(state statemanager.State) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RegisterState", state) +} + +// RegisterState indicates an expected call of RegisterState. +func (mr *MockManagerMockRecorder) RegisterState(state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterState", reflect.TypeOf((*MockManager)(nil).RegisterState), state) +} + +// Start mocks base method. +func (m *MockManager) Start() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Start") +} + +// Start indicates an expected call of Start. +func (mr *MockManagerMockRecorder) Start() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockManager)(nil).Start)) +} + +// Stop mocks base method. +func (m *MockManager) Stop(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Stop indicates an expected call of Stop. +func (mr *MockManagerMockRecorder) Stop(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockManager)(nil).Stop), ctx) +} + +// UpdateState mocks base method. +func (m *MockManager) UpdateState(state statemanager.State) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateState", state) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateState indicates an expected call of UpdateState. +func (mr *MockManagerMockRecorder) UpdateState(state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateState", reflect.TypeOf((*MockManager)(nil).UpdateState), state) +} diff --git a/client/server/state_generic.go b/client/server/state_generic.go index e6c7bdd44..871913cf9 100644 --- a/client/server/state_generic.go +++ b/client/server/state_generic.go @@ -8,7 +8,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func registerStates(mgr *statemanager.Manager) { +func registerStates(mgr statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) } diff --git a/client/server/state_linux.go b/client/server/state_linux.go index 087628907..bca38f90d 100644 --- a/client/server/state_linux.go +++ b/client/server/state_linux.go @@ -10,7 +10,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -func registerStates(mgr *statemanager.Manager) { +func registerStates(mgr statemanager.Manager) { mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&nftables.ShutdownState{})