[client] Support non-PTY no-command interactive SSH sessions (#5093)

This commit is contained in:
Viktor Liu
2026-01-27 18:05:04 +08:00
committed by GitHub
parent d4f7df271a
commit 06966da012
17 changed files with 461 additions and 270 deletions

View File

@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
}
defer func() { _ = serverSession.Close() }()
<-session.Context().Done()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
if err := serverSession.Shell(); err != nil {
log.Debugf("start shell: %v", err)
return
}
done := make(chan error, 1)
go func() {
done <- serverSession.Wait()
}()
select {
case <-session.Context().Done():
return
case err := <-done:
if err != nil {
log.Debugf("shell session: %v", err)
p.handleProxyExitCode(session, err)
}
}
}

View File

@@ -12,8 +12,8 @@ import (
log "github.com/sirupsen/logrus"
)
// handleCommand executes an SSH command with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
// handleExecution executes an SSH command or shell with privilege validation
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
hasPty := winCh != nil
commandType := "command"
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty)
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
logger.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, cleanup, nil
}
cmd.Env = s.prepareCommandEnv(localUser, session)
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, func() {}, nil
}

View File

@@ -15,17 +15,17 @@ import (
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
return nil
}

View File

@@ -10,6 +10,7 @@ import (
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strings"
"sync"
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
return isUtilLinux
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
// createSuCommand creates a command using su - for privilege switching.
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
if err := validateUsername(localUser.Username); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
}
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
args := []string{"-"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username, "-c", command)
args = append(args, localUser.Username)
command := session.RawCommand()
if command != "" {
args = append(args, "-c", command)
}
logger.Debugf("creating su command: %s %v", suPath, args)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
// getShellCommandArgs returns the shell command and arguments for executing a command string.
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
return []string{shell}
}
return []string{shell, "-l", "-c", cmdString}
return []string{shell, "-c", cmdString}
}
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
cmd := exec.CommandContext(ctx, shell, args[1:]...)
cmd.Args[0] = "-" + filepath.Base(shell)
return cmd
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil {
logger.Errorf("Pty command creation failed: %v", err)
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
}()
go func() {
defer func() {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Debugf("session close error: %v", err)
}
}()
if _, err := io.Copy(session, ptmx); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty output copy error: %v", err)
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
case <-ctx.Done():
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
case err := <-done:
s.handlePtyCommandCompletion(logger, session, err)
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
}
}
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
}
}
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
if err != nil {
logger.Debugf("Pty command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
return
} else {
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
// Normal completion
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
// Close PTY to unblock io.Copy goroutines
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close after completion: %v", err)
}
}

View File

@@ -20,32 +20,32 @@ import (
// getUserEnvironment retrieves the Windows environment for the target user.
// Follows OpenSSH's resilient approach with graceful degradation on failures.
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
userToken, err := s.getUserToken(username, domain)
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
userToken, err := s.getUserToken(logger, username, domain)
if err != nil {
return nil, fmt.Errorf("get user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err)
logger.Debugf("close user token: %v", err)
}
}()
return s.getUserEnvironmentWithToken(userToken, username, domain)
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil {
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
}
envMap := make(map[string]string)
if err := s.loadSystemEnvironment(envMap); err != nil {
log.Debugf("failed to load system environment from registry: %v", err)
logger.Debugf("failed to load system environment from registry: %v", err)
}
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
}
// getUserToken creates a user token for the specified user.
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper()
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
token, err := privilegeDropper.createToken(username, domain)
if err != nil {
return 0, fmt.Errorf("generate S4U user token: %w", err)
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
}
// prepareCommandEnv prepares environment variables for command execution on Windows
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
username, domain := s.parseUsername(localUser.Username)
userEnv, err := s.getUserEnvironment(username, domain)
userEnv, err := s.getUserEnvironment(logger, username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
return env
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
return true
}
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
return []string{shell, "-Command", cmdString}
}
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
}
type PtyExecutionRequest struct {
Shell string
Command string
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
Domain string
}
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
privilegeDropper := NewPrivilegeDropper()
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
if err != nil {
return fmt.Errorf("create user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err)
logger.Debugf("close user token: %v", err)
}
}()
server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ()
}
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
Environment: userEnv,
}
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
}
func getUserHomeFromEnv(env []string) string {
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
if err := cmd.Process.Kill(); err != nil {
logger.Debugf("kill process failed: %v", err)
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
}
}
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Command: session.RawCommand(),
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)

View File

@@ -4,12 +4,15 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
"time"
@@ -23,25 +26,67 @@ import (
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
// binary with "ssh exec" args. We handle that here to properly execute commands and
// propagate exit codes.
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
runTestExecutor()
return
}
// Run tests
code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers()
os.Exit(code)
}
// runTestExecutor emulates the netbird executor for tests.
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
func runTestExecutor() {
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
os.Exit(1)
}
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
shell := "/bin/sh"
var command string
for i := 3; i < len(os.Args); i++ {
switch os.Args[i] {
case "--shell":
if i+1 < len(os.Args) {
shell = os.Args[i+1]
i++
}
case "--cmd":
if i+1 < len(os.Args) {
command = os.Args[i+1]
i++
}
}
}
var cmd *exec.Cmd
if command == "" {
cmd = exec.Command(shell)
} else {
cmd = exec.Command(shell, "-c", command)
}
cmd.Args[0] = "-" + filepath.Base(shell)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
os.Exit(0)
}
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
func TestSSHServerCompatibility(t *testing.T) {
if testing.Short() {
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
return createTempKeyFileFromBytes(t, privateKey)
}
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
// This ensures our implementation matches OpenSSH behavior for:
// - ssh host command (no PTY - default when no TTY)
// - ssh -T host command (explicit no PTY)
// - ssh -t host command (force PTY)
// - ssh -T host (no PTY shell - our implementation)
func TestSSHPtyModes(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH PTY mode tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
}
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
username := testutil.GetTestUsername(t)
baseArgs := []string{
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "BatchMode=yes",
}
t.Run("command_default_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
assert.Contains(t, string(output), "no_pty_default")
})
t.Run("command_explicit_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
assert.Contains(t, string(output), "explicit_no_pty")
})
t.Run("command_force_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
assert.Contains(t, string(output), "force_pty")
})
t.Run("shell_explicit_no_pty", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
cmd := exec.CommandContext(ctx, "ssh", args...)
stdin, err := cmd.StdinPipe()
require.NoError(t, err)
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
go func() {
defer stdin.Close()
time.Sleep(100 * time.Millisecond)
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
assert.NoError(t, err, "write echo command")
time.Sleep(100 * time.Millisecond)
_, err = stdin.Write([]byte("exit 0\n"))
assert.NoError(t, err, "write exit command")
}()
output, _ := io.ReadAll(stdout)
err = cmd.Wait()
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
assert.Contains(t, string(output), "shell_no_pty_test")
})
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "Command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
})
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "PTY command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
})
t.Run("stderr_works_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "stderr test failed")
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
})
t.Run("stderr_merged_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "PTY stderr test failed: %s", output)
assert.Contains(t, string(output), "stdout_msg")
assert.Contains(t, string(output), "stderr_msg")
})
}
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
func TestSSHServerFeatureCompatibility(t *testing.T) {
if testing.Short() {

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"syscall"
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
}
// PrivilegeDropper handles secure privilege dropping in child processes
type PrivilegeDropper struct{}
type PrivilegeDropper struct {
logger *log.Entry
}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
// NewPrivilegeDropper creates a new privilege dropper
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
}
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
break
}
}
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
return exec.CommandContext(ctx, netbirdPath, args...), nil
}
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
var execCmd *exec.Cmd
if config.Command == "" {
os.Exit(ExitCodeSuccess)
execCmd = exec.CommandContext(ctx, config.Shell)
} else {
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
}
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
execCmd.Stdin = os.Stdin
execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
if config.Command == "" {
log.Tracef("executing login shell: %s", execCmd.Path)
} else {
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
}
if err := execCmd.Run(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {

View File

@@ -28,22 +28,45 @@ const (
)
type WindowsExecutorConfig struct {
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Interactive bool
Pty bool
PtyWidth int
PtyHeight int
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Pty bool
PtyWidth int
PtyHeight int
}
type PrivilegeDropper struct{}
type PrivilegeDropper struct {
logger *log.Entry
}
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
}
var (
@@ -56,7 +79,6 @@ const (
// Common error messages
commandFlag = "-Command"
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
convertUsernameError = "convert username to UTF16: %w"
convertDomainError = "convert domain to UTF16: %w"
)
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
shellArgs = []string{shell}
}
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)
pd := NewPrivilegeDropper()
pd := NewPrivilegeDropper(WithLogger(logger))
isDomainUser := !pd.isLocalUser(domain)
lsaHandle, err := initializeLsaConnection()
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
return 0, err
}
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
if err != nil {
return 0, err
}
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
}
// buildUserCpn constructs the user principal name
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
}
// prepareS4ULogonStructure creates the appropriate S4U logon structure
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
if isDomainUser {
return prepareDomainS4ULogon(username, domain)
return prepareDomainS4ULogon(logger, username, domain)
}
return prepareLocalS4ULogon(username)
return prepareLocalS4ULogon(logger, username)
}
// prepareDomainS4ULogon creates S4U logon structure for domain users
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
upn, err := lookupPrincipalName(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
}
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
upnUtf16, err := windows.UTF16FromString(upn)
if err != nil {
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
}
// prepareLocalS4ULogon creates S4U logon structure for local users
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
usernameUtf16, err := windows.UTF16FromString(username)
if err != nil {
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
}
// performS4ULogon executes the S4U logon operation
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
var tokenSource tokenSource
copy(tokenSource.SourceName[:], "netbird")
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
log.Debugf("AllocateLocallyUniqueId failed")
logger.Debugf("AllocateLocallyUniqueId failed")
}
originName := newLsaString("netbird")
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
if profile != 0 {
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
}
}
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
}
log.Debugf("created S4U %s token for user %s",
logger.Debugf("created S4U %s token for user %s",
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
return token, nil
}
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
// authenticateLocalUser handles authentication for local users
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(username, ".")
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(pd.log(), username, ".")
if err != nil {
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
}
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
// authenticateDomainUser handles authentication for domain users
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(username, domain)
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(pd.log(), username, domain)
if err != nil {
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
}
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
return token, nil
}
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
defer func() {
if err := windows.CloseHandle(token); err != nil {
log.Debugf("close impersonation token: %v", err)
pd.log().Debugf("close impersonation token: %v", err)
}
}()
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
return cmd, primaryToken, nil
}
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
// createSuCommand creates a command using su - for privilege switching (Windows stub).
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
return nil, fmt.Errorf("su command not available on Windows")
}

View File

@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding
}
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg

View File

@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
sessions = append(sessions, info)
}
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
for key, connState := range s.connections {
remoteAddr := string(key)
if reportedAddrs[remoteAddr] {

View File

@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
}
}
func TestServer_PortForwardingOnlySession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled
func TestServer_NonPtyShellSession(t *testing.T) {
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
expectAllowed bool
description string
}{
{
name: "session_allowed_with_local_forwarding",
name: "shell_with_local_forwarding_enabled",
allowLocalForwarding: true,
allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
},
{
name: "session_allowed_with_remote_forwarding",
name: "shell_with_remote_forwarding_enabled",
allowLocalForwarding: false,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
},
{
name: "session_allowed_with_both",
name: "shell_with_both_forwarding_enabled",
allowLocalForwarding: true,
allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
},
{
name: "session_denied_without_forwarding",
name: "shell_with_forwarding_disabled",
allowLocalForwarding: false,
allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
},
}
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
_ = server.Stop()
}()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
_ = client.Close()
}()
// Execute a command without PTY - this simulates ssh -T with no command
// The server should either allow it (port forwarding enabled) or reject it
output, err := client.ExecuteCommand(ctx, "")
if tt.expectAllowed {
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
// Should always succeed regardless of port forwarding settings
_, err = client.ExecuteCommand(ctx, "")
assert.NoError(t, err, "Non-PTY shell session should be allowed")
})
}
}

View File

@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
} else {
// Test Unix shell behavior
args := server.getShellCommandArgs("/bin/sh", "echo test")
assert.Equal(t, "/bin/sh", args[0])
assert.Equal(t, "-l", args[1])
assert.Equal(t, "-c", args[2])
assert.Equal(t, "echo test", args[3])
assert.Equal(t, "-c", args[1])
assert.Equal(t, "echo test", args[2])
args = server.getShellCommandArgs("/bin/sh", "")
assert.Equal(t, "/bin/sh", args[0])
assert.Len(t, args, 1)
}
}

View File

@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
switch {
case isPty && hasCommand:
// ssh -t <host> <cmd> - Pty command execution
s.handleCommand(logger, session, privilegeResult, winCh)
case isPty:
// ssh <host> - Pty interactive session (login)
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
case hasCommand:
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
// ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
}
}
// handleNonInteractiveSession handles sessions that have no PTY and no command.
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
}
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
} else {
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
}
}

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus"
)
// handlePty is not supported on JS/WASM
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
// handlePtyLogin is not supported on JS/WASM
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
logger.Debugf(errWriteSession, err)

View File

@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
// Returns the command and a cleanup function (no-op on Unix).
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
if err := validateUsername(localUser.Username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
if err != nil {
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
}
privilegeDropper := NewPrivilegeDropper()
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
config := ExecutorConfig{
UID: uid,
GID: gid,
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
shell := getUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
cmd := s.createShellCommand(session.Context(), shell, args)
cmd.Dir = localUser.HomeDir
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)

View File

@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
// createExecutorCommand creates a command using Windows executor for privilege dropping.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
username, _ := s.parseUsername(localUser.Username)
if err := validateUsername(username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
}
return s.createUserSwitchCommand(localUser, session, hasPty)
return s.createUserSwitchCommand(logger, session, localUser)
}
// createUserSwitchCommand creates a command with Windows user switching.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
@@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
}
config := WindowsExecutorConfig{
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Command: command,
Interactive: interactive || (rawCmd == ""),
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Command: command,
}
dropper := NewPrivilegeDropper()
dropper := NewPrivilegeDropper(WithLogger(logger))
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
if err != nil {
return nil, nil, err
@@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
cleanup := func() {
if token != 0 {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
logger.Debugf("close primary token: %v", err)
}
}
}

View File

@@ -56,7 +56,7 @@ var (
)
// ExecutePtyWithUserToken executes a command with ConPty using user token.
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
commandLine := buildCommandLine(args)
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig
Pty: ptyConfig,
User: userConfig,
Session: session,
Context: ctx,
Context: session.Context(),
}
return executeConPtyWithConfig(commandLine, config)