[client] Support non-PTY no-command interactive SSH sessions (#5093)

This commit is contained in:
Viktor Liu
2026-01-27 18:05:04 +08:00
committed by GitHub
parent d4f7df271a
commit 06966da012
17 changed files with 461 additions and 270 deletions

View File

@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
} }
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) { func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession() serverSession, err := sshClient.NewSession()
if err != nil { if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err) _, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
} }
defer func() { _ = serverSession.Close() }() defer func() { _ = serverSession.Close() }()
<-session.Context().Done() serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
if err := session.Exit(0); err != nil { if err := serverSession.Shell(); err != nil {
log.Debugf("session exit: %v", err) log.Debugf("start shell: %v", err)
return
}
done := make(chan error, 1)
go func() {
done <- serverSession.Wait()
}()
select {
case <-session.Context().Done():
return
case err := <-done:
if err != nil {
log.Debugf("shell session: %v", err)
p.handleProxyExitCode(session, err)
}
} }
} }

View File

@@ -12,8 +12,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// handleCommand executes an SSH command with privilege validation // handleExecution executes an SSH command or shell with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
hasPty := winCh != nil hasPty := winCh != nil
commandType := "command" commandType := "command"
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty) execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
if err != nil { if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err) logger.Errorf("%s creation failed: %v", commandType, err)
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
defer cleanup() defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) { if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType) logger.Debugf("%s execution completed", commandType)
} }
} }
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) { func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User localUser := privilegeResult.User
if localUser == nil { if localUser == nil {
return nil, nil, errors.New("no user in privilege result") return nil, nil, errors.New("no user in privilege result")
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
// If PTY requested but su doesn't support --pty, skip su and use executor // If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY) // This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty { if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality") logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err) return nil, nil, fmt.Errorf("create command with privileges: %w", err)
} }
cmd.Env = s.prepareCommandEnv(localUser, session) cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, cleanup, nil return cmd, cleanup, nil
} }
// Try su first for system integration (PAM/audit) when privileged // Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty) cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback { if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err) logger.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty) cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err) return nil, nil, fmt.Errorf("create command with privileges: %w", err)
} }
cmd.Env = s.prepareCommandEnv(localUser, session) cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, cleanup, nil return cmd, cleanup, nil
} }
cmd.Env = s.prepareCommandEnv(localUser, session) cmd.Env = s.prepareCommandEnv(logger, localUser, session)
return cmd, func() {}, nil return cmd, func() {}, nil
} }

View File

@@ -15,17 +15,17 @@ import (
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform") var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM // createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) { func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported return nil, errNotSupported
} }
// createExecutorCommand is not supported on JS/WASM // createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) { func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported return nil, nil, errNotSupported
} }
// prepareCommandEnv is not supported on JS/WASM // prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string { func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
return nil return nil
} }

View File

@@ -10,6 +10,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
return isUtilLinux return isUtilLinux
} }
// createSuCommand creates a command using su -l -c for privilege switching // createSuCommand creates a command using su - for privilege switching.
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
if err := validateUsername(localUser.Username); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
}
suPath, err := exec.LookPath("su") suPath, err := exec.LookPath("su")
if err != nil { if err != nil {
return nil, fmt.Errorf("su command not available: %w", err) return nil, fmt.Errorf("su command not available: %w", err)
} }
command := session.RawCommand() args := []string{"-"}
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty { if hasPty && s.suSupportsPty {
args = append(args, "--pty") args = append(args, "--pty")
} }
args = append(args, localUser.Username, "-c", command) args = append(args, localUser.Username)
command := session.RawCommand()
if command != "" {
args = append(args, "-c", command)
}
logger.Debugf("creating su command: %s %v", suPath, args)
cmd := exec.CommandContext(session.Context(), suPath, args...) cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir cmd.Dir = localUser.HomeDir
return cmd, nil return cmd, nil
} }
// getShellCommandArgs returns the shell command and arguments for executing a command string // getShellCommandArgs returns the shell command and arguments for executing a command string.
func (s *Server) getShellCommandArgs(shell, cmdString string) []string { func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" { if cmdString == "" {
return []string{shell, "-l"} return []string{shell}
} }
return []string{shell, "-l", "-c", cmdString} return []string{shell, "-c", cmdString}
}
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
cmd := exec.CommandContext(ctx, shell, args[1:]...)
cmd.Args[0] = "-" + filepath.Base(shell)
return cmd
} }
// prepareCommandEnv prepares environment variables for command execution on Unix // prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...) env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() { for _, v := range session.Environ() {
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
} }
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil { if err != nil {
logger.Errorf("Pty command creation failed: %v", err) logger.Errorf("Pty command creation failed: %v", err)
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
}() }()
go func() { go func() {
defer func() {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Debugf("session close error: %v", err)
}
}()
if _, err := io.Copy(session, ptmx); err != nil { if _, err := io.Copy(session, ptmx); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty output copy error: %v", err) logger.Warnf("Pty output copy error: %v", err)
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
case <-ctx.Done(): case <-ctx.Done():
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
case err := <-done: case err := <-done:
s.handlePtyCommandCompletion(logger, session, err) s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
} }
} }
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
} }
} }
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
if err != nil { if err != nil {
logger.Debugf("Pty command execution failed: %v", err) logger.Debugf("Pty command execution failed: %v", err)
s.handleSessionExit(session, err, logger) s.handleSessionExit(session, err, logger)
return } else {
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
} }
// Normal completion // Close PTY to unblock io.Copy goroutines
logger.Debugf("Pty command completed successfully") if err := ptyMgr.Close(); err != nil {
if err := session.Exit(0); err != nil { logger.Debugf("Pty close after completion: %v", err)
logSessionExitError(logger, err)
} }
} }

View File

@@ -20,32 +20,32 @@ import (
// getUserEnvironment retrieves the Windows environment for the target user. // getUserEnvironment retrieves the Windows environment for the target user.
// Follows OpenSSH's resilient approach with graceful degradation on failures. // Follows OpenSSH's resilient approach with graceful degradation on failures.
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) { func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
userToken, err := s.getUserToken(username, domain) userToken, err := s.getUserToken(logger, username, domain)
if err != nil { if err != nil {
return nil, fmt.Errorf("get user token: %w", err) return nil, fmt.Errorf("get user token: %w", err)
} }
defer func() { defer func() {
if err := windows.CloseHandle(userToken); err != nil { if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err) logger.Debugf("close user token: %v", err)
} }
}() }()
return s.getUserEnvironmentWithToken(userToken, username, domain) return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
} }
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token. // getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) { func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain) userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil { if err != nil {
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err) logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
userProfile = fmt.Sprintf("C:\\Users\\%s", username) userProfile = fmt.Sprintf("C:\\Users\\%s", username)
} }
envMap := make(map[string]string) envMap := make(map[string]string)
if err := s.loadSystemEnvironment(envMap); err != nil { if err := s.loadSystemEnvironment(envMap); err != nil {
log.Debugf("failed to load system environment from registry: %v", err) logger.Debugf("failed to load system environment from registry: %v", err)
} }
s.setUserEnvironmentVariables(envMap, userProfile, username, domain) s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
} }
// getUserToken creates a user token for the specified user. // getUserToken creates a user token for the specified user.
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) { func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper() privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
token, err := privilegeDropper.createToken(username, domain) token, err := privilegeDropper.createToken(username, domain)
if err != nil { if err != nil {
return 0, fmt.Errorf("generate S4U user token: %w", err) return 0, fmt.Errorf("generate S4U user token: %w", err)
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
} }
// prepareCommandEnv prepares environment variables for command execution on Windows // prepareCommandEnv prepares environment variables for command execution on Windows
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
username, domain := s.parseUsername(localUser.Username) username, domain := s.parseUsername(localUser.Username)
userEnv, err := s.getUserEnvironment(username, domain) userEnv, err := s.getUserEnvironment(logger, username, domain)
if err != nil { if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err) log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
return env return env
} }
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
if privilegeResult.User == nil { if privilegeResult.User == nil {
logger.Errorf("no user in privilege result") logger.Errorf("no user in privilege result")
return false return false
} }
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid) shell := getUserShell(privilegeResult.User.Uid)
logger.Infof("starting interactive shell: %s", shell)
if len(cmd) == 0 { s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true return true
} }
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
return []string{shell, "-Command", cmdString} return []string{shell, "-Command", cmdString}
} }
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
}
type PtyExecutionRequest struct { type PtyExecutionRequest struct {
Shell string Shell string
Command string Command string
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
Domain string Domain string
} }
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error { func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d", logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height) req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
privilegeDropper := NewPrivilegeDropper() privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
userToken, err := privilegeDropper.createToken(req.Username, req.Domain) userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
if err != nil { if err != nil {
return fmt.Errorf("create user token: %w", err) return fmt.Errorf("create user token: %w", err)
} }
defer func() { defer func() {
if err := windows.CloseHandle(userToken); err != nil { if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err) logger.Debugf("close user token: %v", err)
} }
}() }()
server := &Server{} server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain) userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
if err != nil { if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err) logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ() userEnv = os.Environ()
} }
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
Environment: userEnv, Environment: userEnv,
} }
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir) logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig) return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
} }
func getUserHomeFromEnv(env []string) string { func getUserHomeFromEnv(env []string) string {
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
return return
} }
logger := log.WithField("pid", cmd.Process.Pid)
if err := cmd.Process.Kill(); err != nil { if err := cmd.Process.Kill(); err != nil {
logger.Debugf("kill process failed: %v", err) log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
} }
} }
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
} }
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty // executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User localUser := privilegeResult.User
if localUser == nil { if localUser == nil {
logger.Errorf("no user in privilege result") logger.Errorf("no user in privilege result")
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
req := PtyExecutionRequest{ req := PtyExecutionRequest{
Shell: shell, Shell: shell,
Command: command, Command: session.RawCommand(),
Width: ptyReq.Window.Width, Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height, Height: ptyReq.Window.Height,
Username: username, Username: username,
Domain: domain, Domain: domain,
} }
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil { if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err) logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil { if err := session.Exit(1); err != nil {
logSessionExitError(logger, err) logSessionExitError(logger, err)

View File

@@ -4,12 +4,15 @@ import (
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"runtime" "runtime"
"slices"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -23,25 +26,67 @@ import (
"github.com/netbirdio/netbird/client/ssh/testutil" "github.com/netbirdio/netbird/client/ssh/testutil"
) )
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec" // On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
// This happens when running tests as non-privileged user with fallback // spawns an executor subprocess via os.Executable(). During tests, this invokes the test
// binary with "ssh exec" args. We handle that here to properly execute commands and
// propagate exit codes.
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" { if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion runTestExecutor()
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n") return
os.Exit(1)
} }
// Run tests
code := m.Run() code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers() testutil.CleanupTestUsers()
os.Exit(code) os.Exit(code)
} }
// runTestExecutor emulates the netbird executor for tests.
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
func runTestExecutor() {
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
os.Exit(1)
}
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
shell := "/bin/sh"
var command string
for i := 3; i < len(os.Args); i++ {
switch os.Args[i] {
case "--shell":
if i+1 < len(os.Args) {
shell = os.Args[i+1]
i++
}
case "--cmd":
if i+1 < len(os.Args) {
command = os.Args[i+1]
i++
}
}
}
var cmd *exec.Cmd
if command == "" {
cmd = exec.Command(shell)
} else {
cmd = exec.Command(shell, "-c", command)
}
cmd.Args[0] = "-" + filepath.Base(shell)
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
os.Exit(exitErr.ExitCode())
}
os.Exit(1)
}
os.Exit(0)
}
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client // TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
func TestSSHServerCompatibility(t *testing.T) { func TestSSHServerCompatibility(t *testing.T) {
if testing.Short() { if testing.Short() {
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
return createTempKeyFileFromBytes(t, privateKey) return createTempKeyFileFromBytes(t, privateKey)
} }
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
// This ensures our implementation matches OpenSSH behavior for:
// - ssh host command (no PTY - default when no TTY)
// - ssh -T host command (explicit no PTY)
// - ssh -t host command (force PTY)
// - ssh -T host (no PTY shell - our implementation)
func TestSSHPtyModes(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH PTY mode tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
}
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
username := testutil.GetTestUsername(t)
baseArgs := []string{
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "BatchMode=yes",
}
t.Run("command_default_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
assert.Contains(t, string(output), "no_pty_default")
})
t.Run("command_explicit_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
assert.Contains(t, string(output), "explicit_no_pty")
})
t.Run("command_force_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
assert.Contains(t, string(output), "force_pty")
})
t.Run("shell_explicit_no_pty", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
cmd := exec.CommandContext(ctx, "ssh", args...)
stdin, err := cmd.StdinPipe()
require.NoError(t, err)
stdout, err := cmd.StdoutPipe()
require.NoError(t, err)
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
go func() {
defer stdin.Close()
time.Sleep(100 * time.Millisecond)
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
assert.NoError(t, err, "write echo command")
time.Sleep(100 * time.Millisecond)
_, err = stdin.Write([]byte("exit 0\n"))
assert.NoError(t, err, "write exit command")
}()
output, _ := io.ReadAll(stdout)
err = cmd.Wait()
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
assert.Contains(t, string(output), "shell_no_pty_test")
})
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "Command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
})
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
cmd := exec.Command("ssh", args...)
err := cmd.Run()
require.Error(t, err, "PTY command should exit with non-zero")
var exitErr *exec.ExitError
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
})
t.Run("stderr_works_no_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
require.NoError(t, cmd.Run(), "stderr test failed")
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
})
t.Run("stderr_merged_with_pty", func(t *testing.T) {
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
cmd := exec.Command("ssh", args...)
output, err := cmd.CombinedOutput()
require.NoError(t, err, "PTY stderr test failed: %s", output)
assert.Contains(t, string(output), "stdout_msg")
assert.Contains(t, string(output), "stderr_msg")
})
}
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility // TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
func TestSSHServerFeatureCompatibility(t *testing.T) { func TestSSHServerFeatureCompatibility(t *testing.T) {
if testing.Short() { if testing.Short() {

View File

@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"runtime" "runtime"
"strings" "strings"
"syscall" "syscall"
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
} }
// PrivilegeDropper handles secure privilege dropping in child processes // PrivilegeDropper handles secure privilege dropping in child processes
type PrivilegeDropper struct{} type PrivilegeDropper struct {
logger *log.Entry
}
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
type PrivilegeDropperOption func(*PrivilegeDropper)
// NewPrivilegeDropper creates a new privilege dropper // NewPrivilegeDropper creates a new privilege dropper
func NewPrivilegeDropper() *PrivilegeDropper { func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
return &PrivilegeDropper{} pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
} }
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping // CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
break break
} }
} }
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs) pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
return exec.CommandContext(ctx, netbirdPath, args...), nil return exec.CommandContext(ctx, netbirdPath, args...), nil
} }
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
var execCmd *exec.Cmd var execCmd *exec.Cmd
if config.Command == "" { if config.Command == "" {
os.Exit(ExitCodeSuccess) execCmd = exec.CommandContext(ctx, config.Shell)
} else {
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
} }
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
execCmd.Stdin = os.Stdin execCmd.Stdin = os.Stdin
execCmd.Stdout = os.Stdout execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr execCmd.Stderr = os.Stderr
cmdParts := strings.Fields(config.Command) if config.Command == "" {
safeCmd := safeLogCommand(cmdParts) log.Tracef("executing login shell: %s", execCmd.Path)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd) } else {
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
}
if err := execCmd.Run(); err != nil { if err := execCmd.Run(); err != nil {
var exitError *exec.ExitError var exitError *exec.ExitError
if errors.As(err, &exitError) { if errors.As(err, &exitError) {

View File

@@ -28,22 +28,45 @@ const (
) )
type WindowsExecutorConfig struct { type WindowsExecutorConfig struct {
Username string Username string
Domain string Domain string
WorkingDir string WorkingDir string
Shell string Shell string
Command string Command string
Args []string Args []string
Interactive bool Pty bool
Pty bool PtyWidth int
PtyWidth int PtyHeight int
PtyHeight int
} }
type PrivilegeDropper struct{} type PrivilegeDropper struct {
logger *log.Entry
}
func NewPrivilegeDropper() *PrivilegeDropper { // PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
return &PrivilegeDropper{} type PrivilegeDropperOption func(*PrivilegeDropper)
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
pd := &PrivilegeDropper{}
for _, opt := range opts {
opt(pd)
}
return pd
}
// WithLogger sets the logger for the PrivilegeDropper
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
return func(pd *PrivilegeDropper) {
pd.logger = logger
}
}
// log returns the logger, falling back to standard logger if none set
func (pd *PrivilegeDropper) log() *log.Entry {
if pd.logger != nil {
return pd.logger
}
return log.NewEntry(log.StandardLogger())
} }
var ( var (
@@ -56,7 +79,6 @@ const (
// Common error messages // Common error messages
commandFlag = "-Command" commandFlag = "-Command"
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
convertUsernameError = "convert username to UTF16: %w" convertUsernameError = "convert username to UTF16: %w"
convertDomainError = "convert domain to UTF16: %w" convertDomainError = "convert domain to UTF16: %w"
) )
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
shellArgs = []string{shell} shellArgs = []string{shell}
} }
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs) pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser( cmd, token, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir) ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
// generateS4UUserToken creates a Windows token using S4U authentication // generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication // This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(username, domain string) (windows.Handle, error) { func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain) userCpn := buildUserCpn(username, domain)
pd := NewPrivilegeDropper() pd := NewPrivilegeDropper(WithLogger(logger))
isDomainUser := !pd.isLocalUser(domain) isDomainUser := !pd.isLocalUser(domain)
lsaHandle, err := initializeLsaConnection() lsaHandle, err := initializeLsaConnection()
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
return 0, err return 0, err
} }
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser) logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser) return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
} }
// buildUserCpn constructs the user principal name // buildUserCpn constructs the user principal name
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
} }
// prepareS4ULogonStructure creates the appropriate S4U logon structure // prepareS4ULogonStructure creates the appropriate S4U logon structure
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) { func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
if isDomainUser { if isDomainUser {
return prepareDomainS4ULogon(username, domain) return prepareDomainS4ULogon(logger, username, domain)
} }
return prepareLocalS4ULogon(username) return prepareLocalS4ULogon(logger, username)
} }
// prepareDomainS4ULogon creates S4U logon structure for domain users // prepareDomainS4ULogon creates S4U logon structure for domain users
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) { func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
upn, err := lookupPrincipalName(username, domain) upn, err := lookupPrincipalName(username, domain)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("lookup principal name: %w", err) return nil, 0, fmt.Errorf("lookup principal name: %w", err)
} }
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn) logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
upnUtf16, err := windows.UTF16FromString(upn) upnUtf16, err := windows.UTF16FromString(upn)
if err != nil { if err != nil {
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
} }
// prepareLocalS4ULogon creates S4U logon structure for local users // prepareLocalS4ULogon creates S4U logon structure for local users
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) { func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
log.Debugf("using Msv1_0S4ULogon for local user: %s", username) logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
usernameUtf16, err := windows.UTF16FromString(username) usernameUtf16, err := windows.UTF16FromString(username)
if err != nil { if err != nil {
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
} }
// performS4ULogon executes the S4U logon operation // performS4ULogon executes the S4U logon operation
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) { func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
var tokenSource tokenSource var tokenSource tokenSource
copy(tokenSource.SourceName[:], "netbird") copy(tokenSource.SourceName[:], "netbird")
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 { if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
log.Debugf("AllocateLocallyUniqueId failed") logger.Debugf("AllocateLocallyUniqueId failed")
} }
originName := newLsaString("netbird") originName := newLsaString("netbird")
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
if profile != 0 { if profile != 0 {
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess { if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret) logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
} }
} }
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus) return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
} }
log.Debugf("created S4U %s token for user %s", logger.Debugf("created S4U %s token for user %s",
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn) map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
return token, nil return token, nil
} }
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
// authenticateLocalUser handles authentication for local users // authenticateLocalUser handles authentication for local users
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) { func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for local user %s", fullUsername) pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(username, ".") token, err := generateS4UUserToken(pd.log(), username, ".")
if err != nil { if err != nil {
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err) return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
} }
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
// authenticateDomainUser handles authentication for domain users // authenticateDomainUser handles authentication for domain users
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) { func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for domain user %s", fullUsername) pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(username, domain) token, err := generateS4UUserToken(pd.log(), username, domain)
if err != nil { if err != nil {
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err) return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
} }
log.Debugf("Successfully created S4U token for domain user %s", fullUsername) pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
return token, nil return token, nil
} }
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
defer func() { defer func() {
if err := windows.CloseHandle(token); err != nil { if err := windows.CloseHandle(token); err != nil {
log.Debugf("close impersonation token: %v", err) pd.log().Debugf("close impersonation token: %v", err)
} }
}() }()
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
return cmd, primaryToken, nil return cmd, primaryToken, nil
} }
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub) // createSuCommand creates a command using su - for privilege switching (Windows stub).
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) { func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
return nil, fmt.Errorf("su command not available on Windows") return nil, fmt.Errorf("su command not available on Windows")
} }

View File

@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
return s.allowRemotePortForwarding return s.allowRemotePortForwarding
} }
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
func (s *Server) isPortForwardingEnabled() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload // parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) { func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg var payload tcpipForwardMsg

View File

@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
sessions = append(sessions, info) sessions = append(sessions, info)
} }
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only) // Add authenticated connections without sessions (e.g., -N or port-forwarding only)
for key, connState := range s.connections { for key, connState := range s.connections {
remoteAddr := string(key) remoteAddr := string(key)
if reportedAddrs[remoteAddr] { if reportedAddrs[remoteAddr] {

View File

@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
} }
} }
func TestServer_PortForwardingOnlySession(t *testing.T) { func TestServer_NonPtyShellSession(t *testing.T) {
// Test that sessions without PTY and command are allowed when port forwarding is enabled // Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
currentUser, err := user.Current() currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user") require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
name string name string
allowLocalForwarding bool allowLocalForwarding bool
allowRemoteForwarding bool allowRemoteForwarding bool
expectAllowed bool
description string
}{ }{
{ {
name: "session_allowed_with_local_forwarding", name: "shell_with_local_forwarding_enabled",
allowLocalForwarding: true, allowLocalForwarding: true,
allowRemoteForwarding: false, allowRemoteForwarding: false,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
}, },
{ {
name: "session_allowed_with_remote_forwarding", name: "shell_with_remote_forwarding_enabled",
allowLocalForwarding: false, allowLocalForwarding: false,
allowRemoteForwarding: true, allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
}, },
{ {
name: "session_allowed_with_both", name: "shell_with_both_forwarding_enabled",
allowLocalForwarding: true, allowLocalForwarding: true,
allowRemoteForwarding: true, allowRemoteForwarding: true,
expectAllowed: true,
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
}, },
{ {
name: "session_denied_without_forwarding", name: "shell_with_forwarding_disabled",
allowLocalForwarding: false, allowLocalForwarding: false,
allowRemoteForwarding: false, allowRemoteForwarding: false,
expectAllowed: false,
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
}, },
} }
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
_ = server.Stop() _ = server.Stop()
}() }()
// Connect to the server without requesting PTY or command
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
_ = client.Close() _ = client.Close()
}() }()
// Execute a command without PTY - this simulates ssh -T with no command // Execute without PTY and no command - simulates ssh -T (shell without PTY)
// The server should either allow it (port forwarding enabled) or reject it // Should always succeed regardless of port forwarding settings
output, err := client.ExecuteCommand(ctx, "") _, err = client.ExecuteCommand(ctx, "")
if tt.expectAllowed { assert.NoError(t, err, "Non-PTY shell session should be allowed")
// When allowed, the session stays open until cancelled
// ExecuteCommand with empty command should return without error
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
assert.NotContains(t, output, "port forwarding is disabled",
"Output should not contain port forwarding disabled message")
} else if err != nil {
// When denied, we expect an error message about port forwarding being disabled
assert.Contains(t, err.Error(), "port forwarding is disabled",
"Should get port forwarding disabled message")
}
}) })
} }
} }

View File

@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
assert.Equal(t, "-Command", args[1]) assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2]) assert.Equal(t, "echo test", args[2])
} else { } else {
// Test Unix shell behavior
args := server.getShellCommandArgs("/bin/sh", "echo test") args := server.getShellCommandArgs("/bin/sh", "echo test")
assert.Equal(t, "/bin/sh", args[0]) assert.Equal(t, "/bin/sh", args[0])
assert.Equal(t, "-l", args[1]) assert.Equal(t, "-c", args[1])
assert.Equal(t, "-c", args[2]) assert.Equal(t, "echo test", args[2])
assert.Equal(t, "echo test", args[3])
args = server.getShellCommandArgs("/bin/sh", "")
assert.Equal(t, "/bin/sh", args[0])
assert.Len(t, args, 1)
} }
} }

View File

@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty() ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0 hasCommand := len(session.Command()) > 0
switch { if isPty && !hasCommand {
case isPty && hasCommand: // ssh <host> - PTY interactive session (login)
// ssh -t <host> <cmd> - Pty command execution s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
s.handleCommand(logger, session, privilegeResult, winCh) } else {
case isPty: // ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
// ssh <host> - Pty interactive session (login) s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
case hasCommand:
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
// ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
}
}
// handleNonInteractiveSession handles sessions that have no PTY and no command.
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
}
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
} }
} }

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// handlePty is not supported on JS/WASM // handlePtyLogin is not supported on JS/WASM
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool { func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
errorMsg := "PTY sessions are not supported on WASM/JS platform\n" errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil { if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
logger.Debugf(errWriteSession, err) logger.Debugf(errWriteSession, err)

View File

@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping. // createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
// Returns the command and a cleanup function (no-op on Unix). // Returns the command and a cleanup function (no-op on Unix).
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty) logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
if err := validateUsername(localUser.Username); err != nil { if err := validateUsername(localUser.Username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err) return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("parse user credentials: %w", err) return nil, nil, fmt.Errorf("parse user credentials: %w", err)
} }
privilegeDropper := NewPrivilegeDropper() privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
config := ExecutorConfig{ config := ExecutorConfig{
UID: uid, UID: uid,
GID: gid, GID: gid,
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
shell := getUserShell(localUser.Uid) shell := getUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand()) args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...) cmd := s.createShellCommand(session.Context(), shell, args)
cmd.Dir = localUser.HomeDir cmd.Dir = localUser.HomeDir
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session) cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)

View File

@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
// createExecutorCommand creates a command using Windows executor for privilege dropping. // createExecutorCommand creates a command using Windows executor for privilege dropping.
// Returns the command and a cleanup function that must be called after starting the process. // Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) { func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty) logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
username, _ := s.parseUsername(localUser.Username) username, _ := s.parseUsername(localUser.Username)
if err := validateUsername(username); err != nil { if err := validateUsername(username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err) return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
} }
return s.createUserSwitchCommand(localUser, session, hasPty) return s.createUserSwitchCommand(logger, session, localUser)
} }
// createUserSwitchCommand creates a command with Windows user switching. // createUserSwitchCommand creates a command with Windows user switching.
// Returns the command and a cleanup function that must be called after starting the process. // Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) { func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username) username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid) shell := getUserShell(localUser.Uid)
@@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
} }
config := WindowsExecutorConfig{ config := WindowsExecutorConfig{
Username: username, Username: username,
Domain: domain, Domain: domain,
WorkingDir: localUser.HomeDir, WorkingDir: localUser.HomeDir,
Shell: shell, Shell: shell,
Command: command, Command: command,
Interactive: interactive || (rawCmd == ""),
} }
dropper := NewPrivilegeDropper() dropper := NewPrivilegeDropper(WithLogger(logger))
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config) cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
cleanup := func() { cleanup := func() {
if token != 0 { if token != 0 {
if err := windows.CloseHandle(windows.Handle(token)); err != nil { if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err) logger.Debugf("close primary token: %v", err)
} }
} }
} }

View File

@@ -56,7 +56,7 @@ var (
) )
// ExecutePtyWithUserToken executes a command with ConPty using user token. // ExecutePtyWithUserToken executes a command with ConPty using user token.
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error { func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command) args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
commandLine := buildCommandLine(args) commandLine := buildCommandLine(args)
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig
Pty: ptyConfig, Pty: ptyConfig,
User: userConfig, User: userConfig,
Session: session, Session: session,
Context: ctx, Context: session.Context(),
} }
return executeConPtyWithConfig(commandLine, config) return executeConPtyWithConfig(commandLine, config)