Complete overhaul

This commit is contained in:
Viktor Liu
2025-06-24 12:19:53 +02:00
parent f56075ca15
commit 9d1554f9f7
74 changed files with 16626 additions and 4524 deletions

View File

@@ -0,0 +1,466 @@
//go:build windows
package winpty
import (
"context"
"fmt"
"io"
"strings"
"sync"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
extendedStartupInfoPresent = 0x00080000
createUnicodeEnvironment = 0x00000400
procThreadAttributePseudoConsole = 0x00020016
PowerShellCommandFlag = "-Command"
errCloseInputRead = "close input read handle: %v"
errCloseConPtyCleanup = "close ConPty handle during cleanup"
)
// PtyConfig holds configuration for Pty execution.
type PtyConfig struct {
Shell string
Command string
Width int
Height int
WorkingDir string
}
// UserConfig holds user execution configuration.
type UserConfig struct {
Token windows.Handle
Environment []string
}
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
)
// ExecutePtyWithUserToken executes a command with ConPty using user token.
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
commandLine := buildCommandLine(args)
config := ExecutionConfig{
Pty: ptyConfig,
User: userConfig,
Session: session,
Context: ctx,
}
return executeConPtyWithConfig(commandLine, config)
}
// ExecutionConfig holds all execution configuration.
type ExecutionConfig struct {
Pty PtyConfig
User UserConfig
Session ssh.Session
Context context.Context
}
// executeConPtyWithConfig creates ConPty and executes process with configuration.
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
ctx := config.Context
session := config.Session
width := config.Pty.Width
height := config.Pty.Height
userToken := config.User.Token
userEnv := config.User.Environment
workingDir := config.Pty.WorkingDir
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
if err != nil {
return fmt.Errorf("create ConPty pipes: %w", err)
}
hPty, err := createConPty(width, height, inputRead, outputWrite)
if err != nil {
return fmt.Errorf("create ConPty: %w", err)
}
primaryToken, err := duplicateToPrimaryToken(userToken)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("duplicate to primary token: %w", err)
}
defer func() {
if err := windows.CloseHandle(primaryToken); err != nil {
log.Debugf("close primary token: %v", err)
}
}()
siEx, err := setupConPtyStartupInfo(hPty)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("setup startup info: %w", err)
}
defer func() {
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
}()
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("create process as user with ConPty: %w", err)
}
defer closeProcessInfo(pi)
if err := windows.CloseHandle(inputRead); err != nil {
log.Debugf(errCloseInputRead, err)
}
if err := windows.CloseHandle(outputWrite); err != nil {
log.Debugf("close output write handle: %v", err)
}
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, pi.Process)
}
// createConPtyPipes creates input/output pipes for ConPty.
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
}
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
log.Debugf(errCloseInputRead, closeErr)
}
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
log.Debugf("close input write handle: %v", closeErr)
}
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
}
return inputRead, inputWrite, outputRead, outputWrite, nil
}
// createConPty creates a Windows ConPty with the specified size and pipe handles.
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
size := windows.Coord{X: int16(width), Y: int16(height)}
var hPty windows.Handle
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
}
return hPty, nil
}
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
var siEx windows.StartupInfoEx
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
var attrListSize uintptr
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
if ret == 0 && attrListSize == 0 {
return nil, fmt.Errorf("get attribute list size")
}
attrListBytes := make([]byte, attrListSize)
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
ret, _, err := procInitializeProcThreadAttributeList.Call(
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
1,
0,
uintptr(unsafe.Pointer(&attrListSize)),
)
if ret == 0 {
return nil, fmt.Errorf("initialize attribute list: %w", err)
}
ret, _, err = procUpdateProcThreadAttribute.Call(
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
0,
procThreadAttributePseudoConsole,
uintptr(hPty),
unsafe.Sizeof(hPty),
0,
0,
)
if ret == 0 {
return nil, fmt.Errorf("update thread attribute: %w", err)
}
return &siEx, nil
}
// createConPtyProcess creates the actual process with ConPty.
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
var pi windows.ProcessInformation
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
if err != nil {
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
}
envPtr, err := convertEnvironmentToUTF16(userEnv)
if err != nil {
return nil, err
}
var workingDirPtr *uint16
if workingDir != "" {
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
if err != nil {
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
}
}
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
siEx.StartupInfo.StdInput = windows.Handle(0)
siEx.StartupInfo.StdOutput = windows.Handle(0)
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
if userToken != windows.InvalidHandle {
err = windows.CreateProcessAsUser(
windows.Token(userToken),
nil,
commandLinePtr,
nil,
nil,
true,
creationFlags,
envPtr,
workingDirPtr,
&siEx.StartupInfo,
&pi,
)
} else {
err = windows.CreateProcess(
nil,
commandLinePtr,
nil,
nil,
true,
creationFlags,
envPtr,
workingDirPtr,
&siEx.StartupInfo,
&pi,
)
}
if err != nil {
return nil, fmt.Errorf("create process: %w", err)
}
return &pi, nil
}
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
if len(userEnv) == 0 {
return nil, nil
}
var envUTF16 []uint16
for _, envVar := range userEnv {
if envVar != "" {
utf16Str, err := windows.UTF16FromString(envVar)
if err != nil {
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
continue
}
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
envUTF16 = append(envUTF16, 0)
}
}
envUTF16 = append(envUTF16, 0)
if len(envUTF16) > 0 {
return &envUTF16[0], nil
}
return nil, nil
}
// duplicateToPrimaryToken converts an impersonation token to a primary token.
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
var primaryToken windows.Handle
if err := windows.DuplicateTokenEx(
windows.Token(token),
windows.TOKEN_ALL_ACCESS,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
(*windows.Token)(&primaryToken),
); err != nil {
return 0, fmt.Errorf("duplicate token: %w", err)
}
return primaryToken, nil
}
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, process windows.Handle) error {
if err := ctx.Err(); err != nil {
return err
}
var wg sync.WaitGroup
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
processErr := waitForProcess(ctx, process)
if processErr != nil {
return processErr
}
// Clean up in the original order after process completes
if err := reader.Close(); err != nil {
log.Debugf("close reader: %v", err)
}
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
if ret == 0 {
log.Debugf("close ConPty handle: %v", err)
}
wg.Wait()
if err := windows.CloseHandle(outputRead); err != nil {
log.Debugf("close output read handle: %v", err)
}
return nil
}
// startIOBridging starts the I/O bridging goroutines.
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
wg.Add(2)
// Input: reader (SSH session) -> inputWrite (ConPty)
go func() {
defer wg.Done()
defer func() {
if err := windows.CloseHandle(inputWrite); err != nil {
log.Debugf("close input write handle in goroutine: %v", err)
}
}()
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
log.Debugf("input copy ended with error: %v", err)
}
}()
// Output: outputRead (ConPty) -> writer (SSH session)
go func() {
defer wg.Done()
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
log.Debugf("output copy ended with error: %v", err)
}
}()
}
// waitForProcess waits for process completion with context cancellation.
func waitForProcess(ctx context.Context, process windows.Handle) error {
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
return fmt.Errorf("wait for process %d: %w", process, err)
}
return nil
}
// buildShellArgs builds shell arguments for ConPty execution.
func buildShellArgs(shell, command string) []string {
if command != "" {
return []string{shell, PowerShellCommandFlag, command}
}
return []string{shell}
}
// buildCommandLine builds a Windows command line from arguments using proper escaping.
func buildCommandLine(args []string) string {
if len(args) == 0 {
return ""
}
var result strings.Builder
for i, arg := range args {
if i > 0 {
result.WriteString(" ")
}
result.WriteString(syscall.EscapeArg(arg))
}
return result.String()
}
// closeHandles closes multiple Windows handles.
func closeHandles(handles ...windows.Handle) {
for _, handle := range handles {
if handle != windows.InvalidHandle {
if err := windows.CloseHandle(handle); err != nil {
log.Debugf("close handle: %v", err)
}
}
}
}
// closeProcessInfo closes process and thread handles.
func closeProcessInfo(pi *windows.ProcessInformation) {
if pi != nil {
if err := windows.CloseHandle(pi.Process); err != nil {
log.Debugf("close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Debugf("close thread handle: %v", err)
}
}
}
// windowsHandleReader wraps a Windows handle for reading.
type windowsHandleReader struct {
handle windows.Handle
}
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
var bytesRead uint32
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
return 0, err
}
return int(bytesRead), nil
}
func (r *windowsHandleReader) Close() error {
return windows.CloseHandle(r.handle)
}
// windowsHandleWriter wraps a Windows handle for writing.
type windowsHandleWriter struct {
handle windows.Handle
}
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
var bytesWritten uint32
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
return 0, err
}
return int(bytesWritten), nil
}
func (w *windowsHandleWriter) Close() error {
return windows.CloseHandle(w.handle)
}

View File

@@ -0,0 +1,286 @@
//go:build windows
package winpty
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows"
)
func TestBuildShellArgs(t *testing.T) {
tests := []struct {
name string
shell string
command string
expected []string
}{
{
name: "Shell with command",
shell: "powershell.exe",
command: "Get-Process",
expected: []string{"powershell.exe", "-Command", "Get-Process"},
},
{
name: "CMD with command",
shell: "cmd.exe",
command: "dir",
expected: []string{"cmd.exe", "-Command", "dir"},
},
{
name: "Shell interactive",
shell: "powershell.exe",
command: "",
expected: []string{"powershell.exe"},
},
{
name: "CMD interactive",
shell: "cmd.exe",
command: "",
expected: []string{"cmd.exe"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildShellArgs(tt.shell, tt.command)
assert.Equal(t, tt.expected, result)
})
}
}
func TestBuildCommandLine(t *testing.T) {
tests := []struct {
name string
args []string
expected string
}{
{
name: "Simple args",
args: []string{"cmd.exe", "/c", "echo"},
expected: "cmd.exe /c echo",
},
{
name: "Args with spaces",
args: []string{"Program Files\\app.exe", "arg with spaces"},
expected: `"Program Files\app.exe" "arg with spaces"`,
},
{
name: "Args with quotes",
args: []string{"cmd.exe", "/c", `echo "hello world"`},
expected: `cmd.exe /c "echo \"hello world\""`,
},
{
name: "PowerShell calling PowerShell",
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
},
{
name: "Complex nested quotes",
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
},
{
name: "Path with spaces and args",
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
},
{
name: "Empty argument",
args: []string{"cmd.exe", "/c", "echo", ""},
expected: `cmd.exe /c echo ""`,
},
{
name: "Argument with backslashes",
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
expected: `robocopy C:\Source\ C:\Dest\ /E`,
},
{
name: "Empty args",
args: []string{},
expected: "",
},
{
name: "Single arg with space",
args: []string{"path with spaces"},
expected: `"path with spaces"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildCommandLine(tt.args)
assert.Equal(t, tt.expected, result)
})
}
}
func TestCreateConPtyPipes(t *testing.T) {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
require.NoError(t, err, "Should create ConPty pipes successfully")
// Verify all handles are valid
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
// Clean up handles
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
}
func TestCreateConPty(t *testing.T) {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
require.NoError(t, err, "Should create ConPty pipes successfully")
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
hPty, err := createConPty(80, 24, inputRead, outputWrite)
require.NoError(t, err, "Should create ConPty successfully")
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
// Clean up ConPty
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
}
func TestConvertEnvironmentToUTF16(t *testing.T) {
tests := []struct {
name string
userEnv []string
hasError bool
}{
{
name: "Valid environment variables",
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
hasError: false,
},
{
name: "Empty environment",
userEnv: []string{},
hasError: false,
},
{
name: "Environment with empty strings",
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := convertEnvironmentToUTF16(tt.userEnv)
if tt.hasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if len(tt.userEnv) == 0 {
assert.Nil(t, result, "Empty environment should return nil")
} else {
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
}
}
})
}
}
func TestDuplicateToPrimaryToken(t *testing.T) {
if testing.Short() {
t.Skip("Skipping token tests in short mode")
}
// Get current process token for testing
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
require.NoError(t, err, "Should open current process token")
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
t.Logf("Failed to close token: %v", err)
}
}()
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
require.NoError(t, err, "Should duplicate token to primary")
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
// Clean up
err = windows.CloseHandle(primaryToken)
assert.NoError(t, err, "Should close primary token")
}
func TestWindowsHandleReader(t *testing.T) {
// Create a pipe for testing
var readHandle, writeHandle windows.Handle
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
require.NoError(t, err, "Should create pipe for testing")
defer closeHandles(readHandle, writeHandle)
// Write test data
testData := []byte("Hello, Windows Handle Reader!")
var bytesWritten uint32
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
require.NoError(t, err, "Should write test data")
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
// Close write handle to signal EOF
if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err)
}
// Test reading
reader := &windowsHandleReader{handle: readHandle}
buffer := make([]byte, len(testData))
n, err := reader.Read(buffer)
require.NoError(t, err, "Should read from handle")
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
assert.Equal(t, testData, buffer, "Should read expected data")
}
func TestWindowsHandleWriter(t *testing.T) {
// Create a pipe for testing
var readHandle, writeHandle windows.Handle
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
require.NoError(t, err, "Should create pipe for testing")
defer closeHandles(readHandle, writeHandle)
// Test writing
testData := []byte("Hello, Windows Handle Writer!")
writer := &windowsHandleWriter{handle: writeHandle}
n, err := writer.Write(testData)
require.NoError(t, err, "Should write to handle")
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
// Close write handle
if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err)
}
// Verify data was written by reading it back
buffer := make([]byte, len(testData))
var bytesRead uint32
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
require.NoError(t, err, "Should read back written data")
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
assert.Equal(t, testData, buffer, "Should read back expected data")
}
// BenchmarkConPtyCreation benchmarks ConPty creation performance
func BenchmarkConPtyCreation(b *testing.B) {
for i := 0; i < b.N; i++ {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
if err != nil {
b.Fatal(err)
}
hPty, err := createConPty(80, 24, inputRead, outputWrite)
if err != nil {
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
b.Fatal(err)
}
// Clean up
procClosePseudoConsole.Call(uintptr(hPty))
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
}
}