diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index cb1c36e13..8897b9c7e 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) { } func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) { - // Create a backend session to mirror the client's session request. - // This keeps the connection alive on the server side while port forwarding channels operate. serverSession, err := sshClient.NewSession() if err != nil { _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) @@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c } defer func() { _ = serverSession.Close() }() - <-session.Context().Done() + serverSession.Stdin = session + serverSession.Stdout = session + serverSession.Stderr = session.Stderr() - if err := session.Exit(0); err != nil { - log.Debugf("session exit: %v", err) + if err := serverSession.Shell(); err != nil { + log.Debugf("start shell: %v", err) + return + } + + done := make(chan error, 1) + go func() { + done <- serverSession.Wait() + }() + + select { + case <-session.Context().Done(): + return + case err := <-done: + if err != nil { + log.Debugf("shell session: %v", err) + p.handleProxyExitCode(session, err) + } } } diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go index 7a01ce4f6..b0a85fe4b 100644 --- a/client/ssh/server/command_execution.go +++ b/client/ssh/server/command_execution.go @@ -12,8 +12,8 @@ import ( log "github.com/sirupsen/logrus" ) -// handleCommand executes an SSH command with privilege validation -func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { +// handleExecution executes an SSH command or shell with privilege validation +func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) { hasPty := winCh != nil commandType := "command" @@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) - execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) + execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty) if err != nil { logger.Errorf("%s creation failed: %v", commandType, err) @@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege defer cleanup() - ptyReq, _, _ := session.Pty() if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { logger.Debugf("%s execution completed", commandType) } } -func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { +func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { localUser := privilegeResult.User if localUser == nil { return nil, nil, errors.New("no user in privilege result") @@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh // If PTY requested but su doesn't support --pty, skip su and use executor // This ensures PTY functionality is provided (executor runs within our allocated PTY) if hasPty && !s.suSupportsPty { - log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") - cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") + cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty) if err != nil { return nil, nil, fmt.Errorf("create command with privileges: %w", err) } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, cleanup, nil } // Try su first for system integration (PAM/audit) when privileged - cmd, err := s.createSuCommand(session, localUser, hasPty) + cmd, err := s.createSuCommand(logger, session, localUser, hasPty) if err != nil || privilegeResult.UsedFallback { - log.Debugf("su command failed, falling back to executor: %v", err) - cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) + logger.Debugf("su command failed, falling back to executor: %v", err) + cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty) if err != nil { return nil, nil, fmt.Errorf("create command with privileges: %w", err) } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, cleanup, nil } - cmd.Env = s.prepareCommandEnv(localUser, session) + cmd.Env = s.prepareCommandEnv(logger, localUser, session) return cmd, func() {}, nil } diff --git a/client/ssh/server/command_execution_js.go b/client/ssh/server/command_execution_js.go index 01759a337..3aeaa135c 100644 --- a/client/ssh/server/command_execution_js.go +++ b/client/ssh/server/command_execution_js.go @@ -15,17 +15,17 @@ import ( var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") // createSuCommand is not supported on JS/WASM -func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { +func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { return nil, errNotSupported } // createExecutorCommand is not supported on JS/WASM -func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { +func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { return nil, nil, errNotSupported } // prepareCommandEnv is not supported on JS/WASM -func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { +func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string { return nil } diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index db1a9bcfe..279b89341 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "runtime" "strings" "sync" @@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { return isUtilLinux } -// createSuCommand creates a command using su -l -c for privilege switching -func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { +// createSuCommand creates a command using su - for privilege switching. +func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { + if err := validateUsername(localUser.Username); err != nil { + return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) + } + suPath, err := exec.LookPath("su") if err != nil { return nil, fmt.Errorf("su command not available: %w", err) } - command := session.RawCommand() - if command == "" { - return nil, fmt.Errorf("no command specified for su execution") - } - - args := []string{"-l"} + args := []string{"-"} if hasPty && s.suSupportsPty { args = append(args, "--pty") } - args = append(args, localUser.Username, "-c", command) + args = append(args, localUser.Username) + command := session.RawCommand() + if command != "" { + args = append(args, "-c", command) + } + + logger.Debugf("creating su command: %s %v", suPath, args) cmd := exec.CommandContext(session.Context(), suPath, args...) cmd.Dir = localUser.HomeDir return cmd, nil } -// getShellCommandArgs returns the shell command and arguments for executing a command string +// getShellCommandArgs returns the shell command and arguments for executing a command string. func (s *Server) getShellCommandArgs(shell, cmdString string) []string { if cmdString == "" { - return []string{shell, "-l"} + return []string{shell} } - return []string{shell, "-l", "-c", cmdString} + return []string{shell, "-c", cmdString} +} + +// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname". +func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd { + cmd := exec.CommandContext(ctx, shell, args[1:]...) + cmd.Args[0] = "-" + filepath.Base(shell) + return cmd } // prepareCommandEnv prepares environment variables for command execution on Unix -func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { +func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string { env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) env = append(env, prepareSSHEnv(session)...) for _, v := range session.Environ() { @@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) if err != nil { logger.Errorf("Pty command creation failed: %v", err) @@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty }() go func() { - defer func() { - if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { - logger.Debugf("session close error: %v", err) - } - }() if _, err := io.Copy(session, ptmx); err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { logger.Warnf("Pty output copy error: %v", err) @@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex case <-ctx.Done(): s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) case err := <-done: - s.handlePtyCommandCompletion(logger, session, err) + s.handlePtyCommandCompletion(logger, session, ptyMgr, err) } } @@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses } } -func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { +func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) { if err != nil { logger.Debugf("Pty command execution failed: %v", err) s.handleSessionExit(session, err, logger) - return + } else { + logger.Debugf("Pty command completed successfully") + if err := session.Exit(0); err != nil { + logSessionExitError(logger, err) + } } - // Normal completion - logger.Debugf("Pty command completed successfully") - if err := session.Exit(0); err != nil { - logSessionExitError(logger, err) + // Close PTY to unblock io.Copy goroutines + if err := ptyMgr.Close(); err != nil { + logger.Debugf("Pty close after completion: %v", err) } } diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 998796871..e1ba777f6 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -20,32 +20,32 @@ import ( // getUserEnvironment retrieves the Windows environment for the target user. // Follows OpenSSH's resilient approach with graceful degradation on failures. -func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { - userToken, err := s.getUserToken(username, domain) +func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) { + userToken, err := s.getUserToken(logger, username, domain) if err != nil { return nil, fmt.Errorf("get user token: %w", err) } defer func() { if err := windows.CloseHandle(userToken); err != nil { - log.Debugf("close user token: %v", err) + logger.Debugf("close user token: %v", err) } }() - return s.getUserEnvironmentWithToken(userToken, username, domain) + return s.getUserEnvironmentWithToken(logger, userToken, username, domain) } // getUserEnvironmentWithToken retrieves the Windows environment using an existing token. -func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { +func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) { userProfile, err := s.loadUserProfile(userToken, username, domain) if err != nil { - log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) + logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) userProfile = fmt.Sprintf("C:\\Users\\%s", username) } envMap := make(map[string]string) if err := s.loadSystemEnvironment(envMap); err != nil { - log.Debugf("failed to load system environment from registry: %v", err) + logger.Debugf("failed to load system environment from registry: %v", err) } s.setUserEnvironmentVariables(envMap, userProfile, username, domain) @@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, } // getUserToken creates a user token for the specified user. -func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { - privilegeDropper := NewPrivilegeDropper() +func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) token, err := privilegeDropper.createToken(username, domain) if err != nil { return 0, fmt.Errorf("generate S4U user token: %w", err) @@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi } // prepareCommandEnv prepares environment variables for command execution on Windows -func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { +func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string { username, domain := s.parseUsername(localUser.Username) - userEnv, err := s.getUserEnvironment(username, domain) + userEnv, err := s.getUserEnvironment(logger, username, domain) if err != nil { log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) @@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) [] return env } -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool { if privilegeResult.User == nil { logger.Errorf("no user in privilege result") return false } - cmd := session.Command() shell := getUserShell(privilegeResult.User.Uid) + logger.Infof("starting interactive shell: %s", shell) - if len(cmd) == 0 { - logger.Infof("starting interactive shell: %s", shell) - } else { - logger.Infof("executing command: %s", safeLogCommand(cmd)) - } - - s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) + s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil) return true } @@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string { return []string{shell, "-Command", cmdString} } -func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) { - logger.Info("starting interactive shell") - s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand()) -} - type PtyExecutionRequest struct { Shell string Command string @@ -308,25 +297,25 @@ type PtyExecutionRequest struct { Domain string } -func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { - log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", +func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error { + logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) - privilegeDropper := NewPrivilegeDropper() + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) userToken, err := privilegeDropper.createToken(req.Username, req.Domain) if err != nil { return fmt.Errorf("create user token: %w", err) } defer func() { if err := windows.CloseHandle(userToken); err != nil { - log.Debugf("close user token: %v", err) + logger.Debugf("close user token: %v", err) } }() server := &Server{} - userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) + userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain) if err != nil { - log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) + logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) userEnv = os.Environ() } @@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re Environment: userEnv, } - log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) - return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) + logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) + return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig) } func getUserHomeFromEnv(env []string) string { @@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) { return } - logger := log.WithField("pid", cmd.Process.Pid) - if err := cmd.Process.Kill(); err != nil { - logger.Debugf("kill process failed: %v", err) + log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err) } } @@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool { } // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty -func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { - command := session.RawCommand() - if command == "" { - logger.Error("no command specified for PTY execution") - if err := session.Exit(1); err != nil { - logSessionExitError(logger, err) - } - return false - } - - return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command) -} - -// executeConPtyCommand executes a command using ConPty (common for interactive and command execution) -func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool { +func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool { localUser := privilegeResult.User if localUser == nil { logger.Errorf("no user in privilege result") @@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr req := PtyExecutionRequest{ Shell: shell, - Command: command, + Command: session.RawCommand(), Width: ptyReq.Window.Width, Height: ptyReq.Window.Height, Username: username, Domain: domain, } - if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { + if err := executePtyCommandWithUserToken(logger, session, req); err != nil { logger.Errorf("ConPty execution failed: %v", err) if err := session.Exit(1); err != nil { logSessionExitError(logger, err) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 34ffccfd2..7fe2d6c5e 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -4,12 +4,15 @@ import ( "context" "crypto/ed25519" "crypto/rand" + "errors" "fmt" "io" "net" "os" "os/exec" + "path/filepath" "runtime" + "slices" "strings" "testing" "time" @@ -23,25 +26,67 @@ import ( "github.com/netbirdio/netbird/client/ssh/testutil" ) -// TestMain handles package-level setup and cleanup func TestMain(m *testing.M) { - // Guard against infinite recursion when test binary is called as "netbird ssh exec" - // This happens when running tests as non-privileged user with fallback + // On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server + // spawns an executor subprocess via os.Executable(). During tests, this invokes the test + // binary with "ssh exec" args. We handle that here to properly execute commands and + // propagate exit codes. if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { - // Just exit with error to break the recursion - fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") - os.Exit(1) + runTestExecutor() + return } - // Run tests code := m.Run() - - // Cleanup any created test users testutil.CleanupTestUsers() - os.Exit(code) } +// runTestExecutor emulates the netbird executor for tests. +// Parses --shell and --cmd args, runs the command, and exits with the correct code. +func runTestExecutor() { + if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" { + fmt.Fprintf(os.Stderr, "executor recursion detected\n") + os.Exit(1) + } + os.Setenv("_NETBIRD_TEST_EXECUTOR", "1") + + shell := "/bin/sh" + var command string + for i := 3; i < len(os.Args); i++ { + switch os.Args[i] { + case "--shell": + if i+1 < len(os.Args) { + shell = os.Args[i+1] + i++ + } + case "--cmd": + if i+1 < len(os.Args) { + command = os.Args[i+1] + i++ + } + } + } + + var cmd *exec.Cmd + if command == "" { + cmd = exec.Command(shell) + } else { + cmd = exec.Command(shell, "-c", command) + } + cmd.Args[0] = "-" + filepath.Base(shell) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + os.Exit(1) + } + os.Exit(0) +} + // TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client func TestSSHServerCompatibility(t *testing.T) { if testing.Short() { @@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) { return createTempKeyFileFromBytes(t, privateKey) } +// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags) +// This ensures our implementation matches OpenSSH behavior for: +// - ssh host command (no PTY - default when no TTY) +// - ssh -T host command (explicit no PTY) +// - ssh -t host command (force PTY) +// - ssh -T host (no PTY shell - our implementation) +func TestSSHPtyModes(t *testing.T) { + if testing.Short() { + t.Skip("Skipping SSH PTY mode tests in short mode") + } + + if !isSSHClientAvailable() { + t.Skip("SSH client not available on this system") + } + + if runtime.GOOS == "windows" && testutil.IsCI() { + t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues") + } + + hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + require.NoError(t, err) + + clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t) + require.NoError(t, err) + + serverConfig := &Config{ + HostKeyPEM: hostKey, + JWT: nil, + } + server := New(serverConfig) + server.SetAllowRootLogin(true) + + serverAddr := StartTestServer(t, server) + defer func() { + err := server.Stop() + require.NoError(t, err) + }() + + clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH) + defer cleanupKey() + + host, portStr, err := net.SplitHostPort(serverAddr) + require.NoError(t, err) + + username := testutil.GetTestUsername(t) + + baseArgs := []string{ + "-i", clientKeyFile, + "-p", portStr, + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-o", "BatchMode=yes", + } + + t.Run("command_default_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Command (default no PTY) failed: %s", output) + assert.Contains(t, string(output), "no_pty_default") + }) + + t.Run("command_explicit_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output) + assert.Contains(t, string(output), "explicit_no_pty") + }) + + t.Run("command_force_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "Command (-tt force PTY) failed: %s", output) + assert.Contains(t, string(output), "force_pty") + }) + + t.Run("shell_explicit_no_pty", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host)) + cmd := exec.CommandContext(ctx, "ssh", args...) + + stdin, err := cmd.StdinPipe() + require.NoError(t, err) + + stdout, err := cmd.StdoutPipe() + require.NoError(t, err) + + require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed") + + go func() { + defer stdin.Close() + time.Sleep(100 * time.Millisecond) + _, err := stdin.Write([]byte("echo shell_no_pty_test\n")) + assert.NoError(t, err, "write echo command") + time.Sleep(100 * time.Millisecond) + _, err = stdin.Write([]byte("exit 0\n")) + assert.NoError(t, err, "write exit command") + }() + + output, _ := io.ReadAll(stdout) + err = cmd.Wait() + + require.NoError(t, err, "Shell (-T no PTY) failed: %s", output) + assert.Contains(t, string(output), "shell_no_pty_test") + }) + + t.Run("exit_code_preserved_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42") + cmd := exec.Command("ssh", args...) + + err := cmd.Run() + require.Error(t, err, "Command should exit with non-zero") + + var exitErr *exec.ExitError + require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err) + assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T") + }) + + t.Run("exit_code_preserved_with_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'") + cmd := exec.Command("ssh", args...) + + err := cmd.Run() + require.Error(t, err, "PTY command should exit with non-zero") + + var exitErr *exec.ExitError + require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err) + assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt") + }) + + t.Run("stderr_works_no_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), + "sh -c 'echo stdout_msg; echo stderr_msg >&2'") + cmd := exec.Command("ssh", args...) + + var stdout, stderr strings.Builder + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + require.NoError(t, cmd.Run(), "stderr test failed") + assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg") + assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg") + assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg") + }) + + t.Run("stderr_merged_with_pty", func(t *testing.T) { + args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), + "sh -c 'echo stdout_msg; echo stderr_msg >&2'") + cmd := exec.Command("ssh", args...) + + output, err := cmd.CombinedOutput() + require.NoError(t, err, "PTY stderr test failed: %s", output) + assert.Contains(t, string(output), "stdout_msg") + assert.Contains(t, string(output), "stderr_msg") + }) +} + // TestSSHServerFeatureCompatibility tests specific SSH features for compatibility func TestSSHServerFeatureCompatibility(t *testing.T) { if testing.Short() { diff --git a/client/ssh/server/executor_unix.go b/client/ssh/server/executor_unix.go index 8adc824ef..ee0b0ff78 100644 --- a/client/ssh/server/executor_unix.go +++ b/client/ssh/server/executor_unix.go @@ -8,6 +8,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "runtime" "strings" "syscall" @@ -35,11 +36,35 @@ type ExecutorConfig struct { } // PrivilegeDropper handles secure privilege dropping in child processes -type PrivilegeDropper struct{} +type PrivilegeDropper struct { + logger *log.Entry +} + +// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper +type PrivilegeDropperOption func(*PrivilegeDropper) // NewPrivilegeDropper creates a new privilege dropper -func NewPrivilegeDropper() *PrivilegeDropper { - return &PrivilegeDropper{} +func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper { + pd := &PrivilegeDropper{} + for _, opt := range opts { + opt(pd) + } + return pd +} + +// WithLogger sets the logger for the PrivilegeDropper +func WithLogger(logger *log.Entry) PrivilegeDropperOption { + return func(pd *PrivilegeDropper) { + pd.logger = logger + } +} + +// log returns the logger, falling back to standard logger if none set +func (pd *PrivilegeDropper) log() *log.Entry { + if pd.logger != nil { + return pd.logger + } + return log.NewEntry(log.StandardLogger()) } // CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping @@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex break } } - log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) + pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs) return exec.CommandContext(ctx, netbirdPath, args...), nil } @@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config var execCmd *exec.Cmd if config.Command == "" { - os.Exit(ExitCodeSuccess) + execCmd = exec.CommandContext(ctx, config.Shell) + } else { + execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) } - - execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command) + execCmd.Args[0] = "-" + filepath.Base(config.Shell) execCmd.Stdin = os.Stdin execCmd.Stdout = os.Stdout execCmd.Stderr = os.Stderr - cmdParts := strings.Fields(config.Command) - safeCmd := safeLogCommand(cmdParts) - log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + if config.Command == "" { + log.Tracef("executing login shell: %s", execCmd.Path) + } else { + cmdParts := strings.Fields(config.Command) + safeCmd := safeLogCommand(cmdParts) + log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) + } if err := execCmd.Run(); err != nil { var exitError *exec.ExitError if errors.As(err, &exitError) { diff --git a/client/ssh/server/executor_windows.go b/client/ssh/server/executor_windows.go index d3504e056..51c995ec3 100644 --- a/client/ssh/server/executor_windows.go +++ b/client/ssh/server/executor_windows.go @@ -28,22 +28,45 @@ const ( ) type WindowsExecutorConfig struct { - Username string - Domain string - WorkingDir string - Shell string - Command string - Args []string - Interactive bool - Pty bool - PtyWidth int - PtyHeight int + Username string + Domain string + WorkingDir string + Shell string + Command string + Args []string + Pty bool + PtyWidth int + PtyHeight int } -type PrivilegeDropper struct{} +type PrivilegeDropper struct { + logger *log.Entry +} -func NewPrivilegeDropper() *PrivilegeDropper { - return &PrivilegeDropper{} +// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper +type PrivilegeDropperOption func(*PrivilegeDropper) + +func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper { + pd := &PrivilegeDropper{} + for _, opt := range opts { + opt(pd) + } + return pd +} + +// WithLogger sets the logger for the PrivilegeDropper +func WithLogger(logger *log.Entry) PrivilegeDropperOption { + return func(pd *PrivilegeDropper) { + pd.logger = logger + } +} + +// log returns the logger, falling back to standard logger if none set +func (pd *PrivilegeDropper) log() *log.Entry { + if pd.logger != nil { + return pd.logger + } + return log.NewEntry(log.StandardLogger()) } var ( @@ -56,7 +79,6 @@ const ( // Common error messages commandFlag = "-Command" - closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials convertUsernameError = "convert username to UTF16: %w" convertDomainError = "convert domain to UTF16: %w" ) @@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co shellArgs = []string{shell} } - log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) + pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) cmd, token, err := pd.CreateWindowsProcessAsUser( ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) @@ -180,10 +202,10 @@ func newLsaString(s string) lsaString { // generateS4UUserToken creates a Windows token using S4U authentication // This is the exact approach OpenSSH for Windows uses for public key authentication -func generateS4UUserToken(username, domain string) (windows.Handle, error) { +func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) { userCpn := buildUserCpn(username, domain) - pd := NewPrivilegeDropper() + pd := NewPrivilegeDropper(WithLogger(logger)) isDomainUser := !pd.isLocalUser(domain) lsaHandle, err := initializeLsaConnection() @@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) { return 0, err } - logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) + logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser) if err != nil { return 0, err } - return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) + return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) } // buildUserCpn constructs the user principal name @@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) { } // prepareS4ULogonStructure creates the appropriate S4U logon structure -func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { +func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { if isDomainUser { - return prepareDomainS4ULogon(username, domain) + return prepareDomainS4ULogon(logger, username, domain) } - return prepareLocalS4ULogon(username) + return prepareLocalS4ULogon(logger, username) } // prepareDomainS4ULogon creates S4U logon structure for domain users -func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { +func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) { upn, err := lookupPrincipalName(username, domain) if err != nil { return nil, 0, fmt.Errorf("lookup principal name: %w", err) } - log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) + logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) upnUtf16, err := windows.UTF16FromString(upn) if err != nil { @@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er } // prepareLocalS4ULogon creates S4U logon structure for local users -func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { - log.Debugf("using Msv1_0S4ULogon for local user: %s", username) +func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) { + logger.Debugf("using Msv1_0S4ULogon for local user: %s", username) usernameUtf16, err := windows.UTF16FromString(username) if err != nil { @@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { } // performS4ULogon executes the S4U logon operation -func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { +func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { var tokenSource tokenSource copy(tokenSource.SourceName[:], "netbird") if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { - log.Debugf("AllocateLocallyUniqueId failed") + logger.Debugf("AllocateLocallyUniqueId failed") } originName := newLsaString("netbird") @@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u if profile != 0 { if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { - log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) + logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) } } @@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) } - log.Debugf("created S4U %s token for user %s", + logger.Debugf("created S4U %s token for user %s", map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) return token, nil } @@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool { // authenticateLocalUser handles authentication for local users func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { - log.Debugf("using S4U authentication for local user %s", fullUsername) - token, err := generateS4UUserToken(username, ".") + pd.log().Debugf("using S4U authentication for local user %s", fullUsername) + token, err := generateS4UUserToken(pd.log(), username, ".") if err != nil { return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) } @@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) // authenticateDomainUser handles authentication for domain users func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { - log.Debugf("using S4U authentication for domain user %s", fullUsername) - token, err := generateS4UUserToken(username, domain) + pd.log().Debugf("using S4U authentication for domain user %s", fullUsername) + token, err := generateS4UUserToken(pd.log(), username, domain) if err != nil { return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) } - log.Debugf("Successfully created S4U token for domain user %s", fullUsername) + pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername) return token, nil } @@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec defer func() { if err := windows.CloseHandle(token); err != nil { - log.Debugf("close impersonation token: %v", err) + pd.log().Debugf("close impersonation token: %v", err) } }() @@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo return cmd, primaryToken, nil } -// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) -func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { +// createSuCommand creates a command using su - for privilege switching (Windows stub). +func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) { return nil, fmt.Errorf("su command not available on Windows") } diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index c60cf4f58..e16ff5d46 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool { return s.allowRemotePortForwarding } -// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled -func (s *Server) isPortForwardingEnabled() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.allowLocalPortForwarding || s.allowRemotePortForwarding -} - // parseTcpipForwardRequest parses the SSH request payload func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { var payload tcpipForwardMsg diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index e897bbade..1ddb60f8e 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) { sessions = append(sessions, info) } - // Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only) + // Add authenticated connections without sessions (e.g., -N or port-forwarding only) for key, connState := range s.connections { remoteAddr := string(key) if reportedAddrs[remoteAddr] { diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go index d85d85a51..f70e29963 100644 --- a/client/ssh/server/server_config_test.go +++ b/client/ssh/server/server_config_test.go @@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) { } } -func TestServer_PortForwardingOnlySession(t *testing.T) { - // Test that sessions without PTY and command are allowed when port forwarding is enabled +func TestServer_NonPtyShellSession(t *testing.T) { + // Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings. currentUser, err := user.Current() require.NoError(t, err, "Should be able to get current user") - // Generate host key for server hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) require.NoError(t, err) @@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { name string allowLocalForwarding bool allowRemoteForwarding bool - expectAllowed bool - description string }{ { - name: "session_allowed_with_local_forwarding", + name: "shell_with_local_forwarding_enabled", allowLocalForwarding: true, allowRemoteForwarding: false, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when local forwarding is enabled", }, { - name: "session_allowed_with_remote_forwarding", + name: "shell_with_remote_forwarding_enabled", allowLocalForwarding: false, allowRemoteForwarding: true, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when remote forwarding is enabled", }, { - name: "session_allowed_with_both", + name: "shell_with_both_forwarding_enabled", allowLocalForwarding: true, allowRemoteForwarding: true, - expectAllowed: true, - description: "Port-forwarding-only session should be allowed when both forwarding types enabled", }, { - name: "session_denied_without_forwarding", + name: "shell_with_forwarding_disabled", allowLocalForwarding: false, allowRemoteForwarding: false, - expectAllowed: false, - description: "Port-forwarding-only session should be denied when all forwarding is disabled", }, } @@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { _ = server.Stop() }() - // Connect to the server without requesting PTY or command ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) { _ = client.Close() }() - // Execute a command without PTY - this simulates ssh -T with no command - // The server should either allow it (port forwarding enabled) or reject it - output, err := client.ExecuteCommand(ctx, "") - if tt.expectAllowed { - // When allowed, the session stays open until cancelled - // ExecuteCommand with empty command should return without error - assert.NoError(t, err, "Session should be allowed when port forwarding is enabled") - assert.NotContains(t, output, "port forwarding is disabled", - "Output should not contain port forwarding disabled message") - } else if err != nil { - // When denied, we expect an error message about port forwarding being disabled - assert.Contains(t, err.Error(), "port forwarding is disabled", - "Should get port forwarding disabled message") - } + // Execute without PTY and no command - simulates ssh -T (shell without PTY) + // Should always succeed regardless of port forwarding settings + _, err = client.ExecuteCommand(ctx, "") + assert.NoError(t, err, "Non-PTY shell session should be allowed") }) } } diff --git a/client/ssh/server/server_test.go b/client/ssh/server/server_test.go index 661068539..89fab717f 100644 --- a/client/ssh/server/server_test.go +++ b/client/ssh/server/server_test.go @@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) { assert.Equal(t, "-Command", args[1]) assert.Equal(t, "echo test", args[2]) } else { - // Test Unix shell behavior args := server.getShellCommandArgs("/bin/sh", "echo test") assert.Equal(t, "/bin/sh", args[0]) - assert.Equal(t, "-l", args[1]) - assert.Equal(t, "-c", args[2]) - assert.Equal(t, "echo test", args[3]) + assert.Equal(t, "-c", args[1]) + assert.Equal(t, "echo test", args[2]) + + args = server.getShellCommandArgs("/bin/sh", "") + assert.Equal(t, "/bin/sh", args[0]) + assert.Len(t, args, 1) } } diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 3fd578064..f12a75961 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) { ptyReq, winCh, isPty := session.Pty() hasCommand := len(session.Command()) > 0 - switch { - case isPty && hasCommand: - // ssh -t - Pty command execution - s.handleCommand(logger, session, privilegeResult, winCh) - case isPty: - // ssh - Pty interactive session (login) - s.handlePty(logger, session, privilegeResult, ptyReq, winCh) - case hasCommand: - // ssh - non-Pty command execution - s.handleCommand(logger, session, privilegeResult, nil) - default: - // ssh -T (or ssh -N) - no PTY, no command - s.handleNonInteractiveSession(logger, session) - } -} - -// handleNonInteractiveSession handles sessions that have no PTY and no command. -// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N). -func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) { - s.updateSessionType(session, cmdNonInteractive) - - if !s.isPortForwardingEnabled() { - if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil { - logger.Debugf(errWriteSession, err) - } - if err := session.Exit(1); err != nil { - logSessionExitError(logger, err) - } - logger.Infof("rejected non-interactive session: port forwarding disabled") - return - } - - <-session.Context().Done() - - if err := session.Exit(0); err != nil { - logSessionExitError(logger, err) - } -} - -func (s *Server) updateSessionType(session ssh.Session, sessionType string) { - s.mu.Lock() - defer s.mu.Unlock() - - for _, state := range s.sessions { - if state.session == session { - state.sessionType = sessionType - return - } + if isPty && !hasCommand { + // ssh - PTY interactive session (login) + s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh) + } else { + // ssh , ssh -t , ssh -T - command or shell execution + s.handleExecution(logger, session, privilegeResult, ptyReq, winCh) } } diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go index c35e4da0b..4a6cf3d92 100644 --- a/client/ssh/server/session_handlers_js.go +++ b/client/ssh/server/session_handlers_js.go @@ -9,8 +9,8 @@ import ( log "github.com/sirupsen/logrus" ) -// handlePty is not supported on JS/WASM -func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { +// handlePtyLogin is not supported on JS/WASM +func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { errorMsg := "PTY sessions are not supported on WASM/JS platform\n" if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { logger.Debugf(errWriteSession, err) diff --git a/client/ssh/server/userswitching_unix.go b/client/ssh/server/userswitching_unix.go index bc1557419..d80b77042 100644 --- a/client/ssh/server/userswitching_unix.go +++ b/client/ssh/server/userswitching_unix.go @@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) { // createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. // Returns the command and a cleanup function (no-op on Unix). -func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { - log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) +func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) if err := validateUsername(localUser.Username); err != nil { return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) @@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User if err != nil { return nil, nil, fmt.Errorf("parse user credentials: %w", err) } - privilegeDropper := NewPrivilegeDropper() + privilegeDropper := NewPrivilegeDropper(WithLogger(logger)) config := ExecutorConfig{ UID: uid, GID: gid, @@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use shell := getUserShell(localUser.Uid) args := s.getShellCommandArgs(shell, session.RawCommand()) - cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) + cmd := s.createShellCommand(session.Context(), shell, args) cmd.Dir = localUser.HomeDir cmd.Env = s.preparePtyEnv(localUser, ptyReq, session) diff --git a/client/ssh/server/userswitching_windows.go b/client/ssh/server/userswitching_windows.go index 5a5f75fa4..260e1301e 100644 --- a/client/ssh/server/userswitching_windows.go +++ b/client/ssh/server/userswitching_windows.go @@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error { // createExecutorCommand creates a command using Windows executor for privilege dropping. // Returns the command and a cleanup function that must be called after starting the process. -func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { - log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) +func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { + logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) username, _ := s.parseUsername(localUser.Username) if err := validateUsername(username); err != nil { return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) } - return s.createUserSwitchCommand(localUser, session, hasPty) + return s.createUserSwitchCommand(logger, session, localUser) } // createUserSwitchCommand creates a command with Windows user switching. // Returns the command and a cleanup function that must be called after starting the process. -func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { +func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) { username, domain := s.parseUsername(localUser.Username) shell := getUserShell(localUser.Uid) @@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi } config := WindowsExecutorConfig{ - Username: username, - Domain: domain, - WorkingDir: localUser.HomeDir, - Shell: shell, - Command: command, - Interactive: interactive || (rawCmd == ""), + Username: username, + Domain: domain, + WorkingDir: localUser.HomeDir, + Shell: shell, + Command: command, } - dropper := NewPrivilegeDropper() + dropper := NewPrivilegeDropper(WithLogger(logger)) cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) if err != nil { return nil, nil, err @@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi cleanup := func() { if token != 0 { if err := windows.CloseHandle(windows.Handle(token)); err != nil { - log.Debugf("close primary token: %v", err) + logger.Debugf("close primary token: %v", err) } } } diff --git a/client/ssh/server/winpty/conpty.go b/client/ssh/server/winpty/conpty.go index 0f3659ffe..c08ccfd05 100644 --- a/client/ssh/server/winpty/conpty.go +++ b/client/ssh/server/winpty/conpty.go @@ -56,7 +56,7 @@ var ( ) // ExecutePtyWithUserToken executes a command with ConPty using user token. -func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { +func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) commandLine := buildCommandLine(args) @@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig Pty: ptyConfig, User: userConfig, Session: session, - Context: ctx, + Context: session.Context(), } return executeConPtyWithConfig(commandLine, config)