mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[client] Support non-PTY no-command interactive SSH sessions (#5093)
This commit is contained in:
@@ -207,8 +207,6 @@ func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
|
||||||
// Create a backend session to mirror the client's session request.
|
|
||||||
// This keeps the connection alive on the server side while port forwarding channels operate.
|
|
||||||
serverSession, err := sshClient.NewSession()
|
serverSession, err := sshClient.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||||
@@ -216,10 +214,28 @@ func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *c
|
|||||||
}
|
}
|
||||||
defer func() { _ = serverSession.Close() }()
|
defer func() { _ = serverSession.Close() }()
|
||||||
|
|
||||||
<-session.Context().Done()
|
serverSession.Stdin = session
|
||||||
|
serverSession.Stdout = session
|
||||||
|
serverSession.Stderr = session.Stderr()
|
||||||
|
|
||||||
if err := session.Exit(0); err != nil {
|
if err := serverSession.Shell(); err != nil {
|
||||||
log.Debugf("session exit: %v", err)
|
log.Debugf("start shell: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- serverSession.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-session.Context().Done():
|
||||||
|
return
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("shell session: %v", err)
|
||||||
|
p.handleProxyExitCode(session, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleCommand executes an SSH command with privilege validation
|
// handleExecution executes an SSH command or shell with privilege validation
|
||||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
func (s *Server) handleExecution(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) {
|
||||||
hasPty := winCh != nil
|
hasPty := winCh != nil
|
||||||
|
|
||||||
commandType := "command"
|
commandType := "command"
|
||||||
@@ -23,7 +23,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||||
|
|
||||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
execCmd, cleanup, err := s.createCommand(logger, privilegeResult, session, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||||
|
|
||||||
@@ -51,13 +51,12 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
|
|||||||
|
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
ptyReq, _, _ := session.Pty()
|
|
||||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||||
logger.Debugf("%s execution completed", commandType)
|
logger.Debugf("%s execution completed", commandType)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createCommand(logger *log.Entry, privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
return nil, nil, errors.New("no user in privilege result")
|
return nil, nil, errors.New("no user in privilege result")
|
||||||
@@ -66,28 +65,28 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh
|
|||||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||||
if hasPty && !s.suSupportsPty {
|
if hasPty && !s.suSupportsPty {
|
||||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
logger.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, cleanup, nil
|
return cmd, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try su first for system integration (PAM/audit) when privileged
|
// Try su first for system integration (PAM/audit) when privileged
|
||||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
cmd, err := s.createSuCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil || privilegeResult.UsedFallback {
|
if err != nil || privilegeResult.UsedFallback {
|
||||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
logger.Debugf("su command failed, falling back to executor: %v", err)
|
||||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
cmd, cleanup, err := s.createExecutorCommand(logger, session, localUser, hasPty)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||||
}
|
}
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, cleanup, nil
|
return cmd, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
cmd.Env = s.prepareCommandEnv(logger, localUser, session)
|
||||||
return cmd, func() {}, nil
|
return cmd, func() {}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,17 +15,17 @@ import (
|
|||||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||||
|
|
||||||
// createSuCommand is not supported on JS/WASM
|
// createSuCommand is not supported on JS/WASM
|
||||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||||
return nil, errNotSupported
|
return nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// createExecutorCommand is not supported on JS/WASM
|
// createExecutorCommand is not supported on JS/WASM
|
||||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(_ *log.Entry, _ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||||
return nil, nil, errNotSupported
|
return nil, nil, errNotSupported
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv is not supported on JS/WASM
|
// prepareCommandEnv is not supported on JS/WASM
|
||||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(_ *log.Entry, _ *user.User, _ ssh.Session) []string {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -99,40 +100,52 @@ func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool {
|
|||||||
return isUtilLinux
|
return isUtilLinux
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching
|
// createSuCommand creates a command using su - for privilege switching.
|
||||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||||
|
if err := validateUsername(localUser.Username); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||||
|
}
|
||||||
|
|
||||||
suPath, err := exec.LookPath("su")
|
suPath, err := exec.LookPath("su")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("su command not available: %w", err)
|
return nil, fmt.Errorf("su command not available: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
command := session.RawCommand()
|
args := []string{"-"}
|
||||||
if command == "" {
|
|
||||||
return nil, fmt.Errorf("no command specified for su execution")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{"-l"}
|
|
||||||
if hasPty && s.suSupportsPty {
|
if hasPty && s.suSupportsPty {
|
||||||
args = append(args, "--pty")
|
args = append(args, "--pty")
|
||||||
}
|
}
|
||||||
args = append(args, localUser.Username, "-c", command)
|
args = append(args, localUser.Username)
|
||||||
|
|
||||||
|
command := session.RawCommand()
|
||||||
|
if command != "" {
|
||||||
|
args = append(args, "-c", command)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debugf("creating su command: %s %v", suPath, args)
|
||||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||||
cmd.Dir = localUser.HomeDir
|
cmd.Dir = localUser.HomeDir
|
||||||
|
|
||||||
return cmd, nil
|
return cmd, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
// getShellCommandArgs returns the shell command and arguments for executing a command string.
|
||||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||||
if cmdString == "" {
|
if cmdString == "" {
|
||||||
return []string{shell, "-l"}
|
return []string{shell}
|
||||||
}
|
}
|
||||||
return []string{shell, "-l", "-c", cmdString}
|
return []string{shell, "-c", cmdString}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createShellCommand creates an exec.Cmd configured as a login shell by setting argv[0] to "-shellname".
|
||||||
|
func (s *Server) createShellCommand(ctx context.Context, shell string, args []string) *exec.Cmd {
|
||||||
|
cmd := exec.CommandContext(ctx, shell, args[1:]...)
|
||||||
|
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||||
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
env = append(env, prepareSSHEnv(session)...)
|
env = append(env, prepareSSHEnv(session)...)
|
||||||
for _, v := range session.Environ() {
|
for _, v := range session.Environ() {
|
||||||
@@ -154,7 +167,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, e
|
|||||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Pty command creation failed: %v", err)
|
logger.Errorf("Pty command creation failed: %v", err)
|
||||||
@@ -244,11 +257,6 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
|
||||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
logger.Debugf("session close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if _, err := io.Copy(session, ptmx); err != nil {
|
if _, err := io.Copy(session, ptmx); err != nil {
|
||||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||||
logger.Warnf("Pty output copy error: %v", err)
|
logger.Warnf("Pty output copy error: %v", err)
|
||||||
@@ -268,7 +276,7 @@ func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, ex
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||||
case err := <-done:
|
case err := <-done:
|
||||||
s.handlePtyCommandCompletion(logger, session, err)
|
s.handlePtyCommandCompletion(logger, session, ptyMgr, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -296,17 +304,20 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debugf("Pty command execution failed: %v", err)
|
logger.Debugf("Pty command execution failed: %v", err)
|
||||||
s.handleSessionExit(session, err, logger)
|
s.handleSessionExit(session, err, logger)
|
||||||
return
|
} else {
|
||||||
|
logger.Debugf("Pty command completed successfully")
|
||||||
|
if err := session.Exit(0); err != nil {
|
||||||
|
logSessionExitError(logger, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normal completion
|
// Close PTY to unblock io.Copy goroutines
|
||||||
logger.Debugf("Pty command completed successfully")
|
if err := ptyMgr.Close(); err != nil {
|
||||||
if err := session.Exit(0); err != nil {
|
logger.Debugf("Pty close after completion: %v", err)
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,32 +20,32 @@ import (
|
|||||||
|
|
||||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
func (s *Server) getUserEnvironment(logger *log.Entry, username, domain string) ([]string, error) {
|
||||||
userToken, err := s.getUserToken(username, domain)
|
userToken, err := s.getUserToken(logger, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get user token: %w", err)
|
return nil, fmt.Errorf("get user token: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(userToken); err != nil {
|
if err := windows.CloseHandle(userToken); err != nil {
|
||||||
log.Debugf("close user token: %v", err)
|
logger.Debugf("close user token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
return s.getUserEnvironmentWithToken(logger, userToken, username, domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
func (s *Server) getUserEnvironmentWithToken(logger *log.Entry, userToken windows.Handle, username, domain string) ([]string, error) {
|
||||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
logger.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||||
}
|
}
|
||||||
|
|
||||||
envMap := make(map[string]string)
|
envMap := make(map[string]string)
|
||||||
|
|
||||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||||
log.Debugf("failed to load system environment from registry: %v", err)
|
logger.Debugf("failed to load system environment from registry: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||||
@@ -59,8 +59,8 @@ func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getUserToken creates a user token for the specified user.
|
// getUserToken creates a user token for the specified user.
|
||||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
func (s *Server) getUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
token, err := privilegeDropper.createToken(username, domain)
|
token, err := privilegeDropper.createToken(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||||
@@ -242,9 +242,9 @@ func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
userEnv, err := s.getUserEnvironment(username, domain)
|
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||||
@@ -267,22 +267,16 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []
|
|||||||
return env
|
return env
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
if privilegeResult.User == nil {
|
if privilegeResult.User == nil {
|
||||||
logger.Errorf("no user in privilege result")
|
logger.Errorf("no user in privilege result")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := session.Command()
|
|
||||||
shell := getUserShell(privilegeResult.User.Uid)
|
shell := getUserShell(privilegeResult.User.Uid)
|
||||||
|
logger.Infof("starting interactive shell: %s", shell)
|
||||||
|
|
||||||
if len(cmd) == 0 {
|
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||||
logger.Infof("starting interactive shell: %s", shell)
|
|
||||||
} else {
|
|
||||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,11 +288,6 @@ func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
|||||||
return []string{shell, "-Command", cmdString}
|
return []string{shell, "-Command", cmdString}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
|
||||||
logger.Info("starting interactive shell")
|
|
||||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
|
||||||
}
|
|
||||||
|
|
||||||
type PtyExecutionRequest struct {
|
type PtyExecutionRequest struct {
|
||||||
Shell string
|
Shell string
|
||||||
Command string
|
Command string
|
||||||
@@ -308,25 +297,25 @@ type PtyExecutionRequest struct {
|
|||||||
Domain string
|
Domain string
|
||||||
}
|
}
|
||||||
|
|
||||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
func executePtyCommandWithUserToken(logger *log.Entry, session ssh.Session, req PtyExecutionRequest) error {
|
||||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
logger.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||||
|
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create user token: %w", err)
|
return fmt.Errorf("create user token: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(userToken); err != nil {
|
if err := windows.CloseHandle(userToken); err != nil {
|
||||||
log.Debugf("close user token: %v", err)
|
logger.Debugf("close user token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
server := &Server{}
|
server := &Server{}
|
||||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
userEnv, err := server.getUserEnvironmentWithToken(logger, userToken, req.Username, req.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
logger.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||||
userEnv = os.Environ()
|
userEnv = os.Environ()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -348,8 +337,8 @@ func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, re
|
|||||||
Environment: userEnv,
|
Environment: userEnv,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
logger.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
return winpty.ExecutePtyWithUserToken(session, ptyConfig, userConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserHomeFromEnv(env []string) string {
|
func getUserHomeFromEnv(env []string) string {
|
||||||
@@ -371,10 +360,8 @@ func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := log.WithField("pid", cmd.Process.Pid)
|
|
||||||
|
|
||||||
if err := cmd.Process.Kill(); err != nil {
|
if err := cmd.Process.Kill(); err != nil {
|
||||||
logger.Debugf("kill process failed: %v", err)
|
log.Debugf("kill process %d failed: %v", cmd.Process.Pid, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,21 +376,7 @@ func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _ *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
command := session.RawCommand()
|
|
||||||
if command == "" {
|
|
||||||
logger.Error("no command specified for PTY execution")
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
|
||||||
}
|
|
||||||
|
|
||||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
|
||||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
|
||||||
localUser := privilegeResult.User
|
localUser := privilegeResult.User
|
||||||
if localUser == nil {
|
if localUser == nil {
|
||||||
logger.Errorf("no user in privilege result")
|
logger.Errorf("no user in privilege result")
|
||||||
@@ -415,14 +388,14 @@ func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, pr
|
|||||||
|
|
||||||
req := PtyExecutionRequest{
|
req := PtyExecutionRequest{
|
||||||
Shell: shell,
|
Shell: shell,
|
||||||
Command: command,
|
Command: session.RawCommand(),
|
||||||
Width: ptyReq.Window.Width,
|
Width: ptyReq.Window.Width,
|
||||||
Height: ptyReq.Window.Height,
|
Height: ptyReq.Window.Height,
|
||||||
Username: username,
|
Username: username,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
if err := executePtyCommandWithUserToken(logger, session, req); err != nil {
|
||||||
logger.Errorf("ConPty execution failed: %v", err)
|
logger.Errorf("ConPty execution failed: %v", err)
|
||||||
if err := session.Exit(1); err != nil {
|
if err := session.Exit(1); err != nil {
|
||||||
logSessionExitError(logger, err)
|
logSessionExitError(logger, err)
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,25 +26,67 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMain handles package-level setup and cleanup
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||||
// This happens when running tests as non-privileged user with fallback
|
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||||
|
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||||
|
// propagate exit codes.
|
||||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||||
// Just exit with error to break the recursion
|
runTestExecutor()
|
||||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
return
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run tests
|
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
|
||||||
// Cleanup any created test users
|
|
||||||
testutil.CleanupTestUsers()
|
testutil.CleanupTestUsers()
|
||||||
|
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// runTestExecutor emulates the netbird executor for tests.
|
||||||
|
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||||
|
func runTestExecutor() {
|
||||||
|
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||||
|
|
||||||
|
shell := "/bin/sh"
|
||||||
|
var command string
|
||||||
|
for i := 3; i < len(os.Args); i++ {
|
||||||
|
switch os.Args[i] {
|
||||||
|
case "--shell":
|
||||||
|
if i+1 < len(os.Args) {
|
||||||
|
shell = os.Args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--cmd":
|
||||||
|
if i+1 < len(os.Args) {
|
||||||
|
command = os.Args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if command == "" {
|
||||||
|
cmd = exec.Command(shell)
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(shell, "-c", command)
|
||||||
|
}
|
||||||
|
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||||
|
cmd.Stdin = os.Stdin
|
||||||
|
cmd.Stdout = os.Stdout
|
||||||
|
cmd.Stderr = os.Stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
os.Exit(exitErr.ExitCode())
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||||
func TestSSHServerCompatibility(t *testing.T) {
|
func TestSSHServerCompatibility(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
|||||||
return createTempKeyFileFromBytes(t, privateKey)
|
return createTempKeyFileFromBytes(t, privateKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||||
|
// This ensures our implementation matches OpenSSH behavior for:
|
||||||
|
// - ssh host command (no PTY - default when no TTY)
|
||||||
|
// - ssh -T host command (explicit no PTY)
|
||||||
|
// - ssh -t host command (force PTY)
|
||||||
|
// - ssh -T host (no PTY shell - our implementation)
|
||||||
|
func TestSSHPtyModes(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isSSHClientAvailable() {
|
||||||
|
t.Skip("SSH client not available on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
|
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||||
|
}
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer func() {
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||||
|
defer cleanupKey()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
|
baseArgs := []string{
|
||||||
|
"-i", clientKeyFile,
|
||||||
|
"-p", portStr,
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
|
"-o", "ConnectTimeout=5",
|
||||||
|
"-o", "BatchMode=yes",
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "no_pty_default")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "explicit_no_pty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("command_force_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "force_pty")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||||
|
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||||
|
|
||||||
|
stdin, err := cmd.StdinPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer stdin.Close()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||||
|
assert.NoError(t, err, "write echo command")
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
_, err = stdin.Write([]byte("exit 0\n"))
|
||||||
|
assert.NoError(t, err, "write exit command")
|
||||||
|
}()
|
||||||
|
|
||||||
|
output, _ := io.ReadAll(stdout)
|
||||||
|
err = cmd.Wait()
|
||||||
|
|
||||||
|
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.Error(t, err, "Command should exit with non-zero")
|
||||||
|
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||||
|
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.Error(t, err, "PTY command should exit with non-zero")
|
||||||
|
|
||||||
|
var exitErr *exec.ExitError
|
||||||
|
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||||
|
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||||
|
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
var stdout, stderr strings.Builder
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||||
|
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||||
|
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||||
|
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||||
|
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||||
|
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||||
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||||
|
assert.Contains(t, string(output), "stdout_msg")
|
||||||
|
assert.Contains(t, string(output), "stderr_msg")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -35,11 +36,35 @@ type ExecutorConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||||
type PrivilegeDropper struct{}
|
type PrivilegeDropper struct {
|
||||||
|
logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||||
|
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||||
|
|
||||||
// NewPrivilegeDropper creates a new privilege dropper
|
// NewPrivilegeDropper creates a new privilege dropper
|
||||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||||
return &PrivilegeDropper{}
|
pd := &PrivilegeDropper{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pd)
|
||||||
|
}
|
||||||
|
return pd
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the PrivilegeDropper
|
||||||
|
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||||
|
return func(pd *PrivilegeDropper) {
|
||||||
|
pd.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log returns the logger, falling back to standard logger if none set
|
||||||
|
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||||
|
if pd.logger != nil {
|
||||||
|
return pd.logger
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||||
@@ -83,7 +108,7 @@ func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config Ex
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
pd.log().Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,17 +231,22 @@ func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config
|
|||||||
|
|
||||||
var execCmd *exec.Cmd
|
var execCmd *exec.Cmd
|
||||||
if config.Command == "" {
|
if config.Command == "" {
|
||||||
os.Exit(ExitCodeSuccess)
|
execCmd = exec.CommandContext(ctx, config.Shell)
|
||||||
|
} else {
|
||||||
|
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||||
}
|
}
|
||||||
|
execCmd.Args[0] = "-" + filepath.Base(config.Shell)
|
||||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
|
||||||
execCmd.Stdin = os.Stdin
|
execCmd.Stdin = os.Stdin
|
||||||
execCmd.Stdout = os.Stdout
|
execCmd.Stdout = os.Stdout
|
||||||
execCmd.Stderr = os.Stderr
|
execCmd.Stderr = os.Stderr
|
||||||
|
|
||||||
cmdParts := strings.Fields(config.Command)
|
if config.Command == "" {
|
||||||
safeCmd := safeLogCommand(cmdParts)
|
log.Tracef("executing login shell: %s", execCmd.Path)
|
||||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
} else {
|
||||||
|
cmdParts := strings.Fields(config.Command)
|
||||||
|
safeCmd := safeLogCommand(cmdParts)
|
||||||
|
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||||
|
}
|
||||||
if err := execCmd.Run(); err != nil {
|
if err := execCmd.Run(); err != nil {
|
||||||
var exitError *exec.ExitError
|
var exitError *exec.ExitError
|
||||||
if errors.As(err, &exitError) {
|
if errors.As(err, &exitError) {
|
||||||
|
|||||||
@@ -28,22 +28,45 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type WindowsExecutorConfig struct {
|
type WindowsExecutorConfig struct {
|
||||||
Username string
|
Username string
|
||||||
Domain string
|
Domain string
|
||||||
WorkingDir string
|
WorkingDir string
|
||||||
Shell string
|
Shell string
|
||||||
Command string
|
Command string
|
||||||
Args []string
|
Args []string
|
||||||
Interactive bool
|
Pty bool
|
||||||
Pty bool
|
PtyWidth int
|
||||||
PtyWidth int
|
PtyHeight int
|
||||||
PtyHeight int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PrivilegeDropper struct{}
|
type PrivilegeDropper struct {
|
||||||
|
logger *log.Entry
|
||||||
|
}
|
||||||
|
|
||||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
// PrivilegeDropperOption is a functional option for configuring PrivilegeDropper
|
||||||
return &PrivilegeDropper{}
|
type PrivilegeDropperOption func(*PrivilegeDropper)
|
||||||
|
|
||||||
|
func NewPrivilegeDropper(opts ...PrivilegeDropperOption) *PrivilegeDropper {
|
||||||
|
pd := &PrivilegeDropper{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(pd)
|
||||||
|
}
|
||||||
|
return pd
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the logger for the PrivilegeDropper
|
||||||
|
func WithLogger(logger *log.Entry) PrivilegeDropperOption {
|
||||||
|
return func(pd *PrivilegeDropper) {
|
||||||
|
pd.logger = logger
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log returns the logger, falling back to standard logger if none set
|
||||||
|
func (pd *PrivilegeDropper) log() *log.Entry {
|
||||||
|
if pd.logger != nil {
|
||||||
|
return pd.logger
|
||||||
|
}
|
||||||
|
return log.NewEntry(log.StandardLogger())
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -56,7 +79,6 @@ const (
|
|||||||
|
|
||||||
// Common error messages
|
// Common error messages
|
||||||
commandFlag = "-Command"
|
commandFlag = "-Command"
|
||||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
|
||||||
convertUsernameError = "convert username to UTF16: %w"
|
convertUsernameError = "convert username to UTF16: %w"
|
||||||
convertDomainError = "convert domain to UTF16: %w"
|
convertDomainError = "convert domain to UTF16: %w"
|
||||||
)
|
)
|
||||||
@@ -80,7 +102,7 @@ func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, co
|
|||||||
shellArgs = []string{shell}
|
shellArgs = []string{shell}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
pd.log().Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||||
|
|
||||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||||
@@ -180,10 +202,10 @@ func newLsaString(s string) lsaString {
|
|||||||
|
|
||||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
userCpn := buildUserCpn(username, domain)
|
userCpn := buildUserCpn(username, domain)
|
||||||
|
|
||||||
pd := NewPrivilegeDropper()
|
pd := NewPrivilegeDropper(WithLogger(logger))
|
||||||
isDomainUser := !pd.isLocalUser(domain)
|
isDomainUser := !pd.isLocalUser(domain)
|
||||||
|
|
||||||
lsaHandle, err := initializeLsaConnection()
|
lsaHandle, err := initializeLsaConnection()
|
||||||
@@ -197,12 +219,12 @@ func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(logger, username, domain, isDomainUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
return performS4ULogon(logger, lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUserCpn constructs the user principal name
|
// buildUserCpn constructs the user principal name
|
||||||
@@ -310,21 +332,21 @@ func lookupPrincipalName(username, domain string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
func prepareS4ULogonStructure(logger *log.Entry, username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||||
if isDomainUser {
|
if isDomainUser {
|
||||||
return prepareDomainS4ULogon(username, domain)
|
return prepareDomainS4ULogon(logger, username, domain)
|
||||||
}
|
}
|
||||||
return prepareLocalS4ULogon(username)
|
return prepareLocalS4ULogon(logger, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
func prepareDomainS4ULogon(logger *log.Entry, username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||||
upn, err := lookupPrincipalName(username, domain)
|
upn, err := lookupPrincipalName(username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
logger.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||||
|
|
||||||
upnUtf16, err := windows.UTF16FromString(upn)
|
upnUtf16, err := windows.UTF16FromString(upn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -357,8 +379,8 @@ func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
func prepareLocalS4ULogon(logger *log.Entry, username string) (unsafe.Pointer, uintptr, error) {
|
||||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
logger.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||||
|
|
||||||
usernameUtf16, err := windows.UTF16FromString(username)
|
usernameUtf16, err := windows.UTF16FromString(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -406,11 +428,11 @@ func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// performS4ULogon executes the S4U logon operation
|
// performS4ULogon executes the S4U logon operation
|
||||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
func performS4ULogon(logger *log.Entry, lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||||
var tokenSource tokenSource
|
var tokenSource tokenSource
|
||||||
copy(tokenSource.SourceName[:], "netbird")
|
copy(tokenSource.SourceName[:], "netbird")
|
||||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||||
log.Debugf("AllocateLocallyUniqueId failed")
|
logger.Debugf("AllocateLocallyUniqueId failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
originName := newLsaString("netbird")
|
originName := newLsaString("netbird")
|
||||||
@@ -441,7 +463,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
|||||||
|
|
||||||
if profile != 0 {
|
if profile != 0 {
|
||||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
logger.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -449,7 +471,7 @@ func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo u
|
|||||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("created S4U %s token for user %s",
|
logger.Debugf("created S4U %s token for user %s",
|
||||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -497,8 +519,8 @@ func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
|||||||
|
|
||||||
// authenticateLocalUser handles authentication for local users
|
// authenticateLocalUser handles authentication for local users
|
||||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
pd.log().Debugf("using S4U authentication for local user %s", fullUsername)
|
||||||
token, err := generateS4UUserToken(username, ".")
|
token, err := generateS4UUserToken(pd.log(), username, ".")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||||
}
|
}
|
||||||
@@ -507,12 +529,12 @@ func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string)
|
|||||||
|
|
||||||
// authenticateDomainUser handles authentication for domain users
|
// authenticateDomainUser handles authentication for domain users
|
||||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
pd.log().Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||||
token, err := generateS4UUserToken(username, domain)
|
token, err := generateS4UUserToken(pd.log(), username, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||||
}
|
}
|
||||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
pd.log().Debugf("successfully created S4U token for domain user %s", fullUsername)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,7 +548,7 @@ func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, exec
|
|||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := windows.CloseHandle(token); err != nil {
|
if err := windows.CloseHandle(token); err != nil {
|
||||||
log.Debugf("close impersonation token: %v", err)
|
pd.log().Debugf("close impersonation token: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -564,7 +586,7 @@ func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceTo
|
|||||||
return cmd, primaryToken, nil
|
return cmd, primaryToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
// createSuCommand creates a command using su - for privilege switching (Windows stub).
|
||||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
func (s *Server) createSuCommand(*log.Entry, ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||||
return nil, fmt.Errorf("su command not available on Windows")
|
return nil, fmt.Errorf("su command not available on Windows")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -271,13 +271,6 @@ func (s *Server) isRemotePortForwardingAllowed() bool {
|
|||||||
return s.allowRemotePortForwarding
|
return s.allowRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
// isPortForwardingEnabled checks if any port forwarding (local or remote) is enabled
|
|
||||||
func (s *Server) isPortForwardingEnabled() bool {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
return s.allowLocalPortForwarding || s.allowRemotePortForwarding
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseTcpipForwardRequest parses the SSH request payload
|
// parseTcpipForwardRequest parses the SSH request payload
|
||||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||||
var payload tcpipForwardMsg
|
var payload tcpipForwardMsg
|
||||||
|
|||||||
@@ -335,7 +335,7 @@ func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
|||||||
sessions = append(sessions, info)
|
sessions = append(sessions, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add authenticated connections without sessions (e.g., -N/-T or port-forwarding only)
|
// Add authenticated connections without sessions (e.g., -N or port-forwarding only)
|
||||||
for key, connState := range s.connections {
|
for key, connState := range s.connections {
|
||||||
remoteAddr := string(key)
|
remoteAddr := string(key)
|
||||||
if reportedAddrs[remoteAddr] {
|
if reportedAddrs[remoteAddr] {
|
||||||
|
|||||||
@@ -483,12 +483,11 @@ func TestServer_IsPrivilegedUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PortForwardingOnlySession(t *testing.T) {
|
func TestServer_NonPtyShellSession(t *testing.T) {
|
||||||
// Test that sessions without PTY and command are allowed when port forwarding is enabled
|
// Test that non-PTY shell sessions (ssh -T) work regardless of port forwarding settings.
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
require.NoError(t, err, "Should be able to get current user")
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
|
||||||
// Generate host key for server
|
|
||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -496,36 +495,26 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
allowLocalForwarding bool
|
allowLocalForwarding bool
|
||||||
allowRemoteForwarding bool
|
allowRemoteForwarding bool
|
||||||
expectAllowed bool
|
|
||||||
description string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_local_forwarding",
|
name: "shell_with_local_forwarding_enabled",
|
||||||
allowLocalForwarding: true,
|
allowLocalForwarding: true,
|
||||||
allowRemoteForwarding: false,
|
allowRemoteForwarding: false,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when local forwarding is enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_remote_forwarding",
|
name: "shell_with_remote_forwarding_enabled",
|
||||||
allowLocalForwarding: false,
|
allowLocalForwarding: false,
|
||||||
allowRemoteForwarding: true,
|
allowRemoteForwarding: true,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when remote forwarding is enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_allowed_with_both",
|
name: "shell_with_both_forwarding_enabled",
|
||||||
allowLocalForwarding: true,
|
allowLocalForwarding: true,
|
||||||
allowRemoteForwarding: true,
|
allowRemoteForwarding: true,
|
||||||
expectAllowed: true,
|
|
||||||
description: "Port-forwarding-only session should be allowed when both forwarding types enabled",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session_denied_without_forwarding",
|
name: "shell_with_forwarding_disabled",
|
||||||
allowLocalForwarding: false,
|
allowLocalForwarding: false,
|
||||||
allowRemoteForwarding: false,
|
allowRemoteForwarding: false,
|
||||||
expectAllowed: false,
|
|
||||||
description: "Port-forwarding-only session should be denied when all forwarding is disabled",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -545,7 +534,6 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
_ = server.Stop()
|
_ = server.Stop()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Connect to the server without requesting PTY or command
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -557,20 +545,10 @@ func TestServer_PortForwardingOnlySession(t *testing.T) {
|
|||||||
_ = client.Close()
|
_ = client.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Execute a command without PTY - this simulates ssh -T with no command
|
// Execute without PTY and no command - simulates ssh -T (shell without PTY)
|
||||||
// The server should either allow it (port forwarding enabled) or reject it
|
// Should always succeed regardless of port forwarding settings
|
||||||
output, err := client.ExecuteCommand(ctx, "")
|
_, err = client.ExecuteCommand(ctx, "")
|
||||||
if tt.expectAllowed {
|
assert.NoError(t, err, "Non-PTY shell session should be allowed")
|
||||||
// When allowed, the session stays open until cancelled
|
|
||||||
// ExecuteCommand with empty command should return without error
|
|
||||||
assert.NoError(t, err, "Session should be allowed when port forwarding is enabled")
|
|
||||||
assert.NotContains(t, output, "port forwarding is disabled",
|
|
||||||
"Output should not contain port forwarding disabled message")
|
|
||||||
} else if err != nil {
|
|
||||||
// When denied, we expect an error message about port forwarding being disabled
|
|
||||||
assert.Contains(t, err.Error(), "port forwarding is disabled",
|
|
||||||
"Should get port forwarding disabled message")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -405,12 +405,14 @@ func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
|||||||
assert.Equal(t, "-Command", args[1])
|
assert.Equal(t, "-Command", args[1])
|
||||||
assert.Equal(t, "echo test", args[2])
|
assert.Equal(t, "echo test", args[2])
|
||||||
} else {
|
} else {
|
||||||
// Test Unix shell behavior
|
|
||||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||||
assert.Equal(t, "/bin/sh", args[0])
|
assert.Equal(t, "/bin/sh", args[0])
|
||||||
assert.Equal(t, "-l", args[1])
|
assert.Equal(t, "-c", args[1])
|
||||||
assert.Equal(t, "-c", args[2])
|
assert.Equal(t, "echo test", args[2])
|
||||||
assert.Equal(t, "echo test", args[3])
|
|
||||||
|
args = server.getShellCommandArgs("/bin/sh", "")
|
||||||
|
assert.Equal(t, "/bin/sh", args[0])
|
||||||
|
assert.Len(t, args, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -62,54 +62,12 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
|||||||
ptyReq, winCh, isPty := session.Pty()
|
ptyReq, winCh, isPty := session.Pty()
|
||||||
hasCommand := len(session.Command()) > 0
|
hasCommand := len(session.Command()) > 0
|
||||||
|
|
||||||
switch {
|
if isPty && !hasCommand {
|
||||||
case isPty && hasCommand:
|
// ssh <host> - PTY interactive session (login)
|
||||||
// ssh -t <host> <cmd> - Pty command execution
|
s.handlePtyLogin(logger, session, privilegeResult, ptyReq, winCh)
|
||||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
} else {
|
||||||
case isPty:
|
// ssh <host> <cmd>, ssh -t <host> <cmd>, ssh -T <host> - command or shell execution
|
||||||
// ssh <host> - Pty interactive session (login)
|
s.handleExecution(logger, session, privilegeResult, ptyReq, winCh)
|
||||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
|
||||||
case hasCommand:
|
|
||||||
// ssh <host> <cmd> - non-Pty command execution
|
|
||||||
s.handleCommand(logger, session, privilegeResult, nil)
|
|
||||||
default:
|
|
||||||
// ssh -T (or ssh -N) - no PTY, no command
|
|
||||||
s.handleNonInteractiveSession(logger, session)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonInteractiveSession handles sessions that have no PTY and no command.
|
|
||||||
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
|
|
||||||
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
|
|
||||||
s.updateSessionType(session, cmdNonInteractive)
|
|
||||||
|
|
||||||
if !s.isPortForwardingEnabled() {
|
|
||||||
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
|
|
||||||
logger.Debugf(errWriteSession, err)
|
|
||||||
}
|
|
||||||
if err := session.Exit(1); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
logger.Infof("rejected non-interactive session: port forwarding disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
<-session.Context().Done()
|
|
||||||
|
|
||||||
if err := session.Exit(0); err != nil {
|
|
||||||
logSessionExitError(logger, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
for _, state := range s.sessions {
|
|
||||||
if state.session == session {
|
|
||||||
state.sessionType = sessionType
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handlePty is not supported on JS/WASM
|
// handlePtyLogin is not supported on JS/WASM
|
||||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||||
logger.Debugf(errWriteSession, err)
|
logger.Debugf(errWriteSession, err)
|
||||||
|
|||||||
@@ -181,8 +181,8 @@ func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
|||||||
|
|
||||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||||
// Returns the command and a cleanup function (no-op on Unix).
|
// Returns the command and a cleanup function (no-op on Unix).
|
||||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
logger.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||||
|
|
||||||
if err := validateUsername(localUser.Username); err != nil {
|
if err := validateUsername(localUser.Username); err != nil {
|
||||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||||
@@ -192,7 +192,7 @@ func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||||
}
|
}
|
||||||
privilegeDropper := NewPrivilegeDropper()
|
privilegeDropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
config := ExecutorConfig{
|
config := ExecutorConfig{
|
||||||
UID: uid,
|
UID: uid,
|
||||||
GID: gid,
|
GID: gid,
|
||||||
@@ -233,7 +233,7 @@ func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.Use
|
|||||||
shell := getUserShell(localUser.Uid)
|
shell := getUserShell(localUser.Uid)
|
||||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||||
|
|
||||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||||
cmd.Dir = localUser.HomeDir
|
cmd.Dir = localUser.HomeDir
|
||||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||||
|
|
||||||
|
|||||||
@@ -88,20 +88,20 @@ func validateUsernameFormat(username string) error {
|
|||||||
|
|
||||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||||
// Returns the command and a cleanup function that must be called after starting the process.
|
// Returns the command and a cleanup function that must be called after starting the process.
|
||||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
logger.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||||
|
|
||||||
username, _ := s.parseUsername(localUser.Username)
|
username, _ := s.parseUsername(localUser.Username)
|
||||||
if err := validateUsername(username); err != nil {
|
if err := validateUsername(username); err != nil {
|
||||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
return s.createUserSwitchCommand(logger, session, localUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createUserSwitchCommand creates a command with Windows user switching.
|
// createUserSwitchCommand creates a command with Windows user switching.
|
||||||
// Returns the command and a cleanup function that must be called after starting the process.
|
// Returns the command and a cleanup function that must be called after starting the process.
|
||||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||||
username, domain := s.parseUsername(localUser.Username)
|
username, domain := s.parseUsername(localUser.Username)
|
||||||
|
|
||||||
shell := getUserShell(localUser.Uid)
|
shell := getUserShell(localUser.Uid)
|
||||||
@@ -113,15 +113,14 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
|||||||
}
|
}
|
||||||
|
|
||||||
config := WindowsExecutorConfig{
|
config := WindowsExecutorConfig{
|
||||||
Username: username,
|
Username: username,
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
WorkingDir: localUser.HomeDir,
|
WorkingDir: localUser.HomeDir,
|
||||||
Shell: shell,
|
Shell: shell,
|
||||||
Command: command,
|
Command: command,
|
||||||
Interactive: interactive || (rawCmd == ""),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dropper := NewPrivilegeDropper()
|
dropper := NewPrivilegeDropper(WithLogger(logger))
|
||||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@@ -130,7 +129,7 @@ func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Sessi
|
|||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
if token != 0 {
|
if token != 0 {
|
||||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||||
log.Debugf("close primary token: %v", err)
|
logger.Debugf("close primary token: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
func ExecutePtyWithUserToken(session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||||
commandLine := buildCommandLine(args)
|
commandLine := buildCommandLine(args)
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig
|
|||||||
Pty: ptyConfig,
|
Pty: ptyConfig,
|
||||||
User: userConfig,
|
User: userConfig,
|
||||||
Session: session,
|
Session: session,
|
||||||
Context: ctx,
|
Context: session.Context(),
|
||||||
}
|
}
|
||||||
|
|
||||||
return executeConPtyWithConfig(commandLine, config)
|
return executeConPtyWithConfig(commandLine, config)
|
||||||
|
|||||||
Reference in New Issue
Block a user