//go:build unix package server import ( "context" "errors" "fmt" "io" "os" "os/exec" "os/user" "runtime" "strings" "sync" "syscall" "time" "github.com/creack/pty" "github.com/gliderlabs/ssh" log "github.com/sirupsen/logrus" ) // ptyManager manages Pty file operations with thread safety type ptyManager struct { file *os.File mu sync.RWMutex closed bool closeErr error once sync.Once } func newPtyManager(file *os.File) *ptyManager { return &ptyManager{file: file} } func (pm *ptyManager) Close() error { pm.once.Do(func() { pm.mu.Lock() pm.closed = true pm.closeErr = pm.file.Close() pm.mu.Unlock() }) pm.mu.RLock() defer pm.mu.RUnlock() return pm.closeErr } func (pm *ptyManager) Setsize(ws *pty.Winsize) error { pm.mu.RLock() defer pm.mu.RUnlock() if pm.closed { return errors.New("pty is closed") } return pty.Setsize(pm.file, ws) } func (pm *ptyManager) File() *os.File { return pm.file } // detectSuPtySupport checks if su supports the --pty flag func (s *Server) detectSuPtySupport(ctx context.Context) bool { ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) defer cancel() cmd := exec.CommandContext(ctx, "su", "--help") output, err := cmd.CombinedOutput() if err != nil { log.Debugf("su --help failed (may not support --help): %v", err) return false } supported := strings.Contains(string(output), "--pty") log.Debugf("su --pty support detected: %v", supported) return supported } // detectUtilLinuxLogin checks if login is from util-linux (vs shadow-utils). // util-linux login uses vhangup() which requires setsid wrapper to avoid killing parent. // See https://bugs.debian.org/1078023 for details. func (s *Server) detectUtilLinuxLogin(ctx context.Context) bool { if runtime.GOOS != "linux" { return false } ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond) defer cancel() cmd := exec.CommandContext(ctx, "login", "--version") output, err := cmd.CombinedOutput() if err != nil { log.Debugf("login --version failed (likely shadow-utils): %v", err) return false } isUtilLinux := strings.Contains(string(output), "util-linux") log.Debugf("util-linux login detected: %v", isUtilLinux) return isUtilLinux } // createSuCommand creates a command using su -l -c for privilege switching func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) { suPath, err := exec.LookPath("su") if err != nil { return nil, fmt.Errorf("su command not available: %w", err) } command := session.RawCommand() if command == "" { return nil, fmt.Errorf("no command specified for su execution") } args := []string{"-l"} if hasPty && s.suSupportsPty { args = append(args, "--pty") } args = append(args, localUser.Username, "-c", command) cmd := exec.CommandContext(session.Context(), suPath, args...) cmd.Dir = localUser.HomeDir return cmd, nil } // getShellCommandArgs returns the shell command and arguments for executing a command string func (s *Server) getShellCommandArgs(shell, cmdString string) []string { if cmdString == "" { return []string{shell, "-l"} } return []string{shell, "-l", "-c", cmdString} } // prepareCommandEnv prepares environment variables for command execution on Unix func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string { env := prepareUserEnv(localUser, getUserShell(localUser.Uid)) env = append(env, prepareSSHEnv(session)...) for _, v := range session.Environ() { if acceptEnv(v) { env = append(env, v) } } return env } // executeCommandWithPty executes a command with PTY allocation func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { termType := ptyReq.Term if termType == "" { termType = "xterm-256color" } execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType)) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) if err != nil { logger.Errorf("Pty command creation failed: %v", err) errorMsg := "User switching failed - login command not available\r\n" if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil { logger.Debugf(errWriteSession, writeErr) } if err := session.Exit(1); err != nil { logSessionExitError(logger, err) } return false } logger.Infof("starting interactive shell: %s", strings.Join(execCmd.Args, " ")) return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh) } // runPtyCommand runs a command with PTY management (common code for interactive and command execution) func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq) if err != nil { logger.Errorf("Pty start failed: %v", err) if err := session.Exit(1); err != nil { logSessionExitError(logger, err) } return false } ptyMgr := newPtyManager(ptmx) defer func() { if err := ptyMgr.Close(); err != nil { logger.Debugf("Pty close error: %v", err) } }() go s.handlePtyWindowResize(logger, session, ptyMgr, winCh) s.handlePtyIO(logger, session, ptyMgr) s.waitForPtyCompletion(logger, session, execCmd, ptyMgr) return true } func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) { winSize := &pty.Winsize{ Cols: uint16(ptyReq.Window.Width), Rows: uint16(ptyReq.Window.Height), } if winSize.Cols == 0 { winSize.Cols = 80 } if winSize.Rows == 0 { winSize.Rows = 24 } ptmx, err := pty.StartWithSize(execCmd, winSize) if err != nil { return nil, fmt.Errorf("start Pty: %w", err) } return ptmx, nil } func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) { for { select { case <-session.Context().Done(): return case win, ok := <-winCh: if !ok { return } if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil { logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err) } } } } func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) { ptmx := ptyMgr.File() go func() { if _, err := io.Copy(ptmx, session); err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { logger.Warnf("Pty input copy error: %v", err) } } }() go func() { defer func() { if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { logger.Debugf("session close error: %v", err) } }() if _, err := io.Copy(session, ptmx); err != nil { if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { logger.Warnf("Pty output copy error: %v", err) } } }() } func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) { ctx := session.Context() done := make(chan error, 1) go func() { done <- execCmd.Wait() }() select { case <-ctx.Done(): s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done) case err := <-done: s.handlePtyCommandCompletion(logger, session, err) } } func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) { logger.Debugf("Pty session cancelled, terminating command") if err := ptyMgr.Close(); err != nil { logger.Debugf("Pty close during session cancellation: %v", err) } s.killProcessGroup(execCmd) select { case err := <-done: if err != nil { logger.Debugf("Pty command terminated after session cancellation with error: %v", err) } else { logger.Debugf("Pty command terminated after session cancellation") } case <-time.After(5 * time.Second): logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation") } if err := session.Exit(130); err != nil { logSessionExitError(logger, err) } } func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) { if err != nil { logger.Debugf("Pty command execution failed: %v", err) s.handleSessionExit(session, err, logger) return } // Normal completion logger.Debugf("Pty command completed successfully") if err := session.Exit(0); err != nil { logSessionExitError(logger, err) } } func (s *Server) setupProcessGroup(cmd *exec.Cmd) { cmd.SysProcAttr = &syscall.SysProcAttr{ Setpgid: true, } } func (s *Server) killProcessGroup(cmd *exec.Cmd) { if cmd.Process == nil { return } logger := log.WithField("pid", cmd.Process.Pid) pgid := cmd.Process.Pid if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil { logger.Debugf("kill process group SIGTERM: %v", err) return } const gracePeriod = 500 * time.Millisecond const checkInterval = 50 * time.Millisecond ticker := time.NewTicker(checkInterval) defer ticker.Stop() timeout := time.After(gracePeriod) for { select { case <-timeout: if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil { logger.Debugf("kill process group SIGKILL: %v", err) } return case <-ticker.C: if err := syscall.Kill(-pgid, 0); err != nil { return } } } }