diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 6ca941626..9712ae42f 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -186,19 +186,9 @@ func parseCustomSSHFlags(args []string) ([]string, []string, []string) { arg := args[i] switch { case strings.HasPrefix(arg, "-L"): - if arg == "-L" && i+1 < len(args) { - localForwardFlags = append(localForwardFlags, args[i+1]) - i++ - } else if len(arg) > 2 { - localForwardFlags = append(localForwardFlags, arg[2:]) - } + localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags) case strings.HasPrefix(arg, "-R"): - if arg == "-R" && i+1 < len(args) { - remoteForwardFlags = append(remoteForwardFlags, args[i+1]) - i++ - } else if len(arg) > 2 { - remoteForwardFlags = append(remoteForwardFlags, arg[2:]) - } + remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags) default: filteredArgs = append(filteredArgs, arg) } @@ -207,6 +197,18 @@ func parseCustomSSHFlags(args []string) ([]string, []string, []string) { return filteredArgs, localForwardFlags, remoteForwardFlags } +func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) { + if arg == "-L" || arg == "-R" { + if i+1 < len(args) { + flags = append(flags, args[i+1]) + i++ + } + } else if len(arg) > 2 { + flags = append(flags, arg[2:]) + } + return flags, i +} + // extractGlobalFlags parses global flags that were passed before 'ssh' command func extractGlobalFlags(args []string) { sshPos := findSSHCommandPosition(args) diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 1dc5c72e1..defa16247 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -370,7 +370,23 @@ func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCa // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { - // Connect to NetBird daemon using the same logic as CLI + client, err := connectToDaemon(daemonAddr) + if err != nil { + return err + } + defer func() { + if err := client.Close(); err != nil { + log.Debugf("daemon connection close error: %v", err) + } + }() + + addresses := buildAddressList(hostname, remote) + log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses) + + return verifyKeyWithDaemon(client, addresses, key) +} + +func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -382,64 +398,70 @@ func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, ) if err != nil { log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err) - return fmt.Errorf("failed to connect to NetBird daemon: %w", err) + return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err) } - defer func() { - if err := conn.Close(); err != nil { - log.Debugf("daemon connection close error: %v", err) - } - }() + return conn, nil +} - client := proto.NewDaemonServiceClient(conn) - - // Try both hostname and IP address from remote.String() +func buildAddressList(hostname string, remote net.Addr) []string { addresses := []string{hostname} if host, _, err := net.SplitHostPort(remote.String()); err == nil { if host != hostname { addresses = append(addresses, host) } } + return addresses +} - log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses) +func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error { + client := proto.NewDaemonServiceClient(conn) for _, addr := range addresses { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{ - PeerAddress: addr, - }) - cancel() - - log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err) - - if err != nil { - log.Debugf("daemon query error for %s: %v", addr, err) - continue - } - - if !response.GetFound() { - log.Debugf("SSH host key not found in daemon for address: %s", addr) - continue - } - - // Parse the stored SSH host key - storedKey, _, _, _, err := ssh.ParseAuthorizedKey(response.GetSshHostKey()) - if err != nil { - log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err) - continue - } - - // Compare the keys - if key.Type() == storedKey.Type() && string(key.Marshal()) == string(storedKey.Marshal()) { - log.Debugf("SSH host key verified via NetBird daemon for %s", addr) + if err := checkAddressKey(client, addr, key); err == nil { return nil - } else { - log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), key.Type()) } } - return fmt.Errorf("SSH host key not found or does not match in NetBird daemon") } +func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{ + PeerAddress: addr, + }) + log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err) + + if err != nil { + log.Debugf("daemon query error for %s: %v", addr, err) + return err + } + + if !response.GetFound() { + log.Debugf("SSH host key not found in daemon for address: %s", addr) + return fmt.Errorf("key not found") + } + + return compareKeys(response.GetSshHostKey(), key, addr) +} + +func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error { + storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData) + if err != nil { + log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err) + return err + } + + if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) { + log.Debugf("SSH host key verified via NetBird daemon for %s", addr) + return nil + } + + log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type()) + return fmt.Errorf("key mismatch") +} + // getKnownHostsFiles returns paths to known_hosts files in order of preference func getKnownHostsFiles() []string { var files []string @@ -469,41 +491,49 @@ func getKnownHostsFiles() []string { // createHostKeyCallbackWithOptions creates a host key verification callback with custom options func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) { return func(hostname string, remote net.Addr, key ssh.PublicKey) error { - // First try to get host key from NetBird daemon (if daemon address provided) - if opts.DaemonAddr != "" { - if err := verifyHostKeyViaDaemon(hostname, remote, key, opts.DaemonAddr); err == nil { - return nil - } + if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil { + return nil } - - // Fallback to known_hosts files - var knownHostsFiles []string - - if opts.KnownHostsFile != "" { - knownHostsFiles = append(knownHostsFiles, opts.KnownHostsFile) - } else { - knownHostsFiles = getKnownHostsFiles() - } - - var hostKeyCallbacks []ssh.HostKeyCallback - - for _, file := range knownHostsFiles { - if callback, err := knownhosts.New(file); err == nil { - hostKeyCallbacks = append(hostKeyCallbacks, callback) - } - } - - // Try each known_hosts callback - for _, callback := range hostKeyCallbacks { - if err := callback(hostname, remote, key); err == nil { - return nil - } - } - - return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file") + return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile) }, nil } +func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { + if daemonAddr == "" { + return fmt.Errorf("no daemon address") + } + return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr) +} + +func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error { + knownHostsFiles := getKnownHostsFilesList(knownHostsFile) + hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles) + + for _, callback := range hostKeyCallbacks { + if err := callback(hostname, remote, key); err == nil { + return nil + } + } + return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file") +} + +func getKnownHostsFilesList(knownHostsFile string) []string { + if knownHostsFile != "" { + return []string{knownHostsFile} + } + return getKnownHostsFiles() +} + +func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback { + var hostKeyCallbacks []ssh.HostKeyCallback + for _, file := range knownHostsFiles { + if callback, err := knownhosts.New(file); err == nil { + hostKeyCallbacks = append(hostKeyCallbacks, callback) + } + } + return hostKeyCallbacks +} + // createSSHKeyAuth creates SSH key authentication from a private key file func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) { keyData, err := os.ReadFile(keyFile) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index ab59c3d15..209d75e81 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -143,13 +143,7 @@ func NewManager() *Manager { func getSystemSSHPaths() (configDir, knownHostsDir string) { switch runtime.GOOS { case "windows": - // Windows OpenSSH paths - programData := os.Getenv("PROGRAMDATA") - if programData == "" { - programData = `C:\ProgramData` - } - configDir = filepath.Join(programData, "ssh", "ssh_config.d") - knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d") + configDir, knownHostsDir = getWindowsSSHPaths() default: // Unix-like systems (Linux, macOS, etc.) configDir = "/etc/ssh/ssh_config.d" @@ -158,6 +152,16 @@ func getSystemSSHPaths() (configDir, knownHostsDir string) { return configDir, knownHostsDir } +func getWindowsSSHPaths() (configDir, knownHostsDir string) { + programData := os.Getenv("PROGRAMDATA") + if programData == "" { + programData = `C:\ProgramData` + } + configDir = filepath.Join(programData, "ssh", "ssh_config.d") + knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d") + return configDir, knownHostsDir +} + // SetupSSHClientConfig creates SSH client configuration for NetBird domains func (m *Manager) SetupSSHClientConfig(domains []string) error { return m.SetupSSHClientConfigWithPeers(domains, nil) @@ -165,75 +169,95 @@ func (m *Manager) SetupSSHClientConfig(domains []string) error { // SetupSSHClientConfigWithPeers creates SSH client configuration for peer hostnames func (m *Manager) SetupSSHClientConfigWithPeers(domains []string, peerKeys []PeerHostKey) error { - peerCount := len(peerKeys) - - // Check if SSH config should be generated - if !shouldGenerateSSHConfig(peerCount) { - if isSSHConfigDisabled() { - log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig) - } else { - log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.", - peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig) - } + if !shouldGenerateSSHConfig(len(peerKeys)) { + m.logSkipReason(len(peerKeys)) return nil } - // Try to set up known_hosts for host key verification + + knownHostsPath := m.getKnownHostsPath() + sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath) + return m.writeSSHConfig(sshConfig, domains) +} + +func (m *Manager) logSkipReason(peerCount int) { + if isSSHConfigDisabled() { + log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig) + } else { + log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.", + peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig) + } +} + +func (m *Manager) getKnownHostsPath() string { knownHostsPath, err := m.setupKnownHostsFile() if err != nil { log.Warnf("Failed to setup known_hosts file: %v", err) - // Continue with fallback to no verification - knownHostsPath = "/dev/null" + return "/dev/null" + } + return knownHostsPath +} + +func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string { + sshConfig := m.buildConfigHeader() + for _, peer := range peerKeys { + sshConfig += m.buildPeerConfig(peer, knownHostsPath) + } + return sshConfig +} + +func (m *Manager) buildConfigHeader() string { + return "# NetBird SSH client configuration\n" + + "# Generated automatically - do not edit manually\n" + + "#\n" + + "# To disable SSH config management, use:\n" + + "# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" + + "#\n\n" +} + +func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string { + hostPatterns := m.buildHostPatterns(peer) + if len(hostPatterns) == 0 { + return "" } + hostLine := strings.Join(hostPatterns, " ") + config := fmt.Sprintf("Host %s\n", hostLine) + config += " # NetBird peer-specific configuration\n" + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += m.buildHostKeyConfig(knownHostsPath) + config += " LogLevel ERROR\n\n" + return config +} + +func (m *Manager) buildHostPatterns(peer PeerHostKey) []string { + var hostPatterns []string + if peer.IP != "" { + hostPatterns = append(hostPatterns, peer.IP) + } + if peer.FQDN != "" { + hostPatterns = append(hostPatterns, peer.FQDN) + } + if peer.Hostname != "" && peer.Hostname != peer.FQDN { + hostPatterns = append(hostPatterns, peer.Hostname) + } + return hostPatterns +} + +func (m *Manager) buildHostKeyConfig(knownHostsPath string) string { + if knownHostsPath == "/dev/null" { + return " StrictHostKeyChecking no\n" + + " UserKnownHostsFile /dev/null\n" + } + return " StrictHostKeyChecking yes\n" + + fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath) +} + +func (m *Manager) writeSSHConfig(sshConfig string, domains []string) error { sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) - // Build SSH client configuration - sshConfig := "# NetBird SSH client configuration\n" - sshConfig += "# Generated automatically - do not edit manually\n" - sshConfig += "#\n" - sshConfig += "# To disable SSH config management, use:\n" - sshConfig += "# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" - sshConfig += "#\n\n" - - // Add specific peer entries with multiple hostnames in one Host line - for _, peer := range peerKeys { - var hostPatterns []string - - // Add IP address - if peer.IP != "" { - hostPatterns = append(hostPatterns, peer.IP) - } - - // Add FQDN - if peer.FQDN != "" { - hostPatterns = append(hostPatterns, peer.FQDN) - } - - // Add short hostname if different from FQDN - if peer.Hostname != "" && peer.Hostname != peer.FQDN { - hostPatterns = append(hostPatterns, peer.Hostname) - } - - if len(hostPatterns) > 0 { - hostLine := strings.Join(hostPatterns, " ") - sshConfig += fmt.Sprintf("Host %s\n", hostLine) - sshConfig += " # NetBird peer-specific configuration\n" - sshConfig += " PreferredAuthentications password,publickey,keyboard-interactive\n" - sshConfig += " PasswordAuthentication yes\n" - sshConfig += " PubkeyAuthentication yes\n" - sshConfig += " BatchMode no\n" - if knownHostsPath == "/dev/null" { - sshConfig += " StrictHostKeyChecking no\n" - sshConfig += " UserKnownHostsFile /dev/null\n" - } else { - sshConfig += " StrictHostKeyChecking yes\n" - sshConfig += fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath) - } - sshConfig += " LogLevel ERROR\n\n" - } - } - - // Try to create system-wide SSH config if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err) return m.setupUserConfig(sshConfig, domains) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 552545adc..a692da264 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -39,7 +39,7 @@ func TestSSHServerCompatibility(t *testing.T) { require.NoError(t, err) // Generate OpenSSH-compatible keys for client - clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey() + clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t) require.NoError(t, err) server := New(hostKey) @@ -270,7 +270,7 @@ func isSSHClientAvailable() bool { } // generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use. -func generateOpenSSHKey() ([]byte, []byte, error) { +func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) { // Check if ssh-keygen is available if _, err := exec.LookPath("ssh-keygen"); err != nil { // Fall back to our existing key generation and try to convert diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index f1132e7ad..06d4e5a07 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -52,17 +52,20 @@ func (s *Server) sessionHandler(session ssh.Session) { // ssh - non-Pty command execution s.handleCommand(logger, session, privilegeResult, ssh.Pty{}, nil) default: - // ssh - no Pty, no command (invalid) - if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { - logger.Debugf(errWriteSession, err) - } - if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) - } - logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) + s.rejectInvalidSession(logger, session) } } +func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { + if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil { + logger.Debugf(errWriteSession, err) + } + if err := session.Exit(1); err != nil { + logger.Debugf(errExitSession, err) + } + logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) +} + func (s *Server) registerSession(session ssh.Session) SessionKey { sessionID := session.Context().Value(ssh.ContextKeySessionID) if sessionID == nil { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 62b2f151f..c8d854a94 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -437,72 +437,7 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject { ) // Create save and cancel buttons - saveButton := widget.NewButton("Save", func() { - if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { - // validate preSharedKey if it added - if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { - dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings) - return - } - } - - port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) - if err != nil { - dialog.ShowError(errors.New("Invalid interface port"), s.wSettings) - return - } - - iAdminURL := strings.TrimSpace(s.iAdminURL.Text) - iMngURL := strings.TrimSpace(s.iMngURL.Text) - - defer s.wSettings.Close() - - // Check if any settings have changed - if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text || - s.adminURL != iAdminURL || s.RosenpassPermissive != s.sRosenpassPermissive.Checked || - s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) || - s.networkMonitor != s.sNetworkMonitor.Checked || - s.disableDNS != s.sDisableDNS.Checked || - s.disableClientRoutes != s.sDisableClientRoutes.Checked || - s.disableServerRoutes != s.sDisableServerRoutes.Checked || - s.blockLANAccess != s.sBlockLANAccess.Checked || - s.enableSSHRoot != s.sEnableSSHRoot.Checked || - s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || - s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || - s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked { - - s.managementURL = iMngURL - s.preSharedKey = s.iPreSharedKey.Text - s.adminURL = iAdminURL - - loginRequest := proto.LoginRequest{ - ManagementUrl: iMngURL, - AdminURL: iAdminURL, - IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", - RosenpassPermissive: &s.sRosenpassPermissive.Checked, - InterfaceName: &s.iInterfaceName.Text, - WireguardPort: &port, - NetworkMonitor: &s.sNetworkMonitor.Checked, - DisableDns: &s.sDisableDNS.Checked, - DisableClientRoutes: &s.sDisableClientRoutes.Checked, - DisableServerRoutes: &s.sDisableServerRoutes.Checked, - BlockLanAccess: &s.sBlockLANAccess.Checked, - EnableSSHRoot: &s.sEnableSSHRoot.Checked, - EnableSSHSFTP: &s.sEnableSSHSFTP.Checked, - EnableSSHLocalPortForwarding: &s.sEnableSSHLocalPortForward.Checked, - EnableSSHRemotePortForwarding: &s.sEnableSSHRemotePortForward.Checked, - } - - if s.iPreSharedKey.Text != censoredPreSharedKey { - loginRequest.OptionalPreSharedKey = &s.iPreSharedKey.Text - } - - if err := s.restartClient(&loginRequest); err != nil { - log.Errorf("restarting client connection: %v", err) - return - } - } - }) + saveButton := widget.NewButton("Save", s.handleSaveSettings) cancelButton := widget.NewButton("Cancel", func() { s.wSettings.Close() @@ -519,6 +454,105 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject { return container.NewBorder(nil, buttonContainer, nil, nil, tabs) } +func (s *serviceClient) handleSaveSettings() { + defer s.wSettings.Close() + + if err := s.validateSettings(); err != nil { + dialog.ShowError(err, s.wSettings) + return + } + + port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64) + if err != nil { + dialog.ShowError(errors.New("Invalid interface port"), s.wSettings) + return + } + + iAdminURL := strings.TrimSpace(s.iAdminURL.Text) + iMngURL := strings.TrimSpace(s.iMngURL.Text) + + if s.hasSettingsChanged(iMngURL, iAdminURL, int(port)) { + s.applySettings(iMngURL, iAdminURL, port) + } +} + +func (s *serviceClient) validateSettings() error { + if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { + if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { + return fmt.Errorf("Invalid Pre-shared Key Value") + } + } + return nil +} + +func (s *serviceClient) hasSettingsChanged(iMngURL, iAdminURL string, port int) bool { + return s.managementURL != iMngURL || + s.preSharedKey != s.iPreSharedKey.Text || + s.adminURL != iAdminURL || + s.hasInterfaceChanges(port) || + s.hasNetworkChanges() || + s.hasSSHChanges() +} + +func (s *serviceClient) hasInterfaceChanges(port int) bool { + return s.RosenpassPermissive != s.sRosenpassPermissive.Checked || + s.interfaceName != s.iInterfaceName.Text || + s.interfacePort != port +} + +func (s *serviceClient) hasNetworkChanges() bool { + return s.networkMonitor != s.sNetworkMonitor.Checked || + s.disableDNS != s.sDisableDNS.Checked || + s.disableClientRoutes != s.sDisableClientRoutes.Checked || + s.disableServerRoutes != s.sDisableServerRoutes.Checked || + s.blockLANAccess != s.sBlockLANAccess.Checked +} + +func (s *serviceClient) hasSSHChanges() bool { + return s.enableSSHRoot != s.sEnableSSHRoot.Checked || + s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || + s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || + s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked +} + +func (s *serviceClient) applySettings(iMngURL, iAdminURL string, port int64) { + s.managementURL = iMngURL + s.preSharedKey = s.iPreSharedKey.Text + s.adminURL = iAdminURL + + loginRequest := s.buildLoginRequest(iMngURL, iAdminURL, port) + + if err := s.restartClient(&loginRequest); err != nil { + log.Errorf("restarting client connection: %v", err) + } +} + +func (s *serviceClient) buildLoginRequest(iMngURL, iAdminURL string, port int64) proto.LoginRequest { + loginRequest := proto.LoginRequest{ + ManagementUrl: iMngURL, + AdminURL: iAdminURL, + IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", + RosenpassPermissive: &s.sRosenpassPermissive.Checked, + InterfaceName: &s.iInterfaceName.Text, + WireguardPort: &port, + NetworkMonitor: &s.sNetworkMonitor.Checked, + DisableDns: &s.sDisableDNS.Checked, + DisableClientRoutes: &s.sDisableClientRoutes.Checked, + DisableServerRoutes: &s.sDisableServerRoutes.Checked, + BlockLanAccess: &s.sBlockLANAccess.Checked, + EnableSSHRoot: &s.sEnableSSHRoot.Checked, + EnableSSHSFTP: &s.sEnableSSHSFTP.Checked, + EnableSSHLocalPortForwarding: &s.sEnableSSHLocalPortForward.Checked, + EnableSSHRemotePortForwarding: &s.sEnableSSHRemotePortForward.Checked, + } + + if s.iPreSharedKey.Text != censoredPreSharedKey { + loginRequest.OptionalPreSharedKey = &s.iPreSharedKey.Text + } + + return loginRequest +} + func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil {