mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
Complete overhaul
This commit is contained in:
466
client/ssh/server/winpty/conpty.go
Normal file
466
client/ssh/server/winpty/conpty.go
Normal file
@@ -0,0 +1,466 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
extendedStartupInfoPresent = 0x00080000
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
procThreadAttributePseudoConsole = 0x00020016
|
||||
|
||||
PowerShellCommandFlag = "-Command"
|
||||
|
||||
errCloseInputRead = "close input read handle: %v"
|
||||
errCloseConPtyCleanup = "close ConPty handle during cleanup"
|
||||
)
|
||||
|
||||
// PtyConfig holds configuration for Pty execution.
|
||||
type PtyConfig struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
WorkingDir string
|
||||
}
|
||||
|
||||
// UserConfig holds user execution configuration.
|
||||
type UserConfig struct {
|
||||
Token windows.Handle
|
||||
Environment []string
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
|
||||
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
|
||||
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
|
||||
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
|
||||
)
|
||||
|
||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||
commandLine := buildCommandLine(args)
|
||||
|
||||
config := ExecutionConfig{
|
||||
Pty: ptyConfig,
|
||||
User: userConfig,
|
||||
Session: session,
|
||||
Context: ctx,
|
||||
}
|
||||
|
||||
return executeConPtyWithConfig(commandLine, config)
|
||||
}
|
||||
|
||||
// ExecutionConfig holds all execution configuration.
|
||||
type ExecutionConfig struct {
|
||||
Pty PtyConfig
|
||||
User UserConfig
|
||||
Session ssh.Session
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
// executeConPtyWithConfig creates ConPty and executes process with configuration.
|
||||
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
|
||||
ctx := config.Context
|
||||
session := config.Session
|
||||
width := config.Pty.Width
|
||||
height := config.Pty.Height
|
||||
userToken := config.User.Token
|
||||
userEnv := config.User.Environment
|
||||
workingDir := config.Pty.WorkingDir
|
||||
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty pipes: %w", err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(width, height, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty: %w", err)
|
||||
}
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(userToken)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("duplicate to primary token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(primaryToken); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
siEx, err := setupConPtyStartupInfo(hPty)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("setup startup info: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
|
||||
}()
|
||||
|
||||
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("create process as user with ConPty: %w", err)
|
||||
}
|
||||
defer closeProcessInfo(pi)
|
||||
|
||||
if err := windows.CloseHandle(inputRead); err != nil {
|
||||
log.Debugf(errCloseInputRead, err)
|
||||
}
|
||||
if err := windows.CloseHandle(outputWrite); err != nil {
|
||||
log.Debugf("close output write handle: %v", err)
|
||||
}
|
||||
|
||||
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, pi.Process)
|
||||
}
|
||||
|
||||
// createConPtyPipes creates input/output pipes for ConPty.
|
||||
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
|
||||
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
|
||||
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
|
||||
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
|
||||
log.Debugf(errCloseInputRead, closeErr)
|
||||
}
|
||||
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
|
||||
log.Debugf("close input write handle: %v", closeErr)
|
||||
}
|
||||
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
|
||||
}
|
||||
|
||||
return inputRead, inputWrite, outputRead, outputWrite, nil
|
||||
}
|
||||
|
||||
// createConPty creates a Windows ConPty with the specified size and pipe handles.
|
||||
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
|
||||
size := windows.Coord{X: int16(width), Y: int16(height)}
|
||||
|
||||
var hPty windows.Handle
|
||||
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
|
||||
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
|
||||
}
|
||||
|
||||
return hPty, nil
|
||||
}
|
||||
|
||||
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
|
||||
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
|
||||
var siEx windows.StartupInfoEx
|
||||
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
|
||||
|
||||
var attrListSize uintptr
|
||||
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
|
||||
if ret == 0 && attrListSize == 0 {
|
||||
return nil, fmt.Errorf("get attribute list size")
|
||||
}
|
||||
|
||||
attrListBytes := make([]byte, attrListSize)
|
||||
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
|
||||
|
||||
ret, _, err := procInitializeProcThreadAttributeList.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
1,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&attrListSize)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("initialize attribute list: %w", err)
|
||||
}
|
||||
|
||||
ret, _, err = procUpdateProcThreadAttribute.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
0,
|
||||
procThreadAttributePseudoConsole,
|
||||
uintptr(hPty),
|
||||
unsafe.Sizeof(hPty),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("update thread attribute: %w", err)
|
||||
}
|
||||
|
||||
return &siEx, nil
|
||||
}
|
||||
|
||||
// createConPtyProcess creates the actual process with ConPty.
|
||||
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
|
||||
var pi windows.ProcessInformation
|
||||
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
|
||||
|
||||
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
|
||||
}
|
||||
|
||||
envPtr, err := convertEnvironmentToUTF16(userEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var workingDirPtr *uint16
|
||||
if workingDir != "" {
|
||||
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
|
||||
siEx.StartupInfo.StdInput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdOutput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
|
||||
|
||||
if userToken != windows.InvalidHandle {
|
||||
err = windows.CreateProcessAsUser(
|
||||
windows.Token(userToken),
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
} else {
|
||||
err = windows.CreateProcess(
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process: %w", err)
|
||||
}
|
||||
|
||||
return &pi, nil
|
||||
}
|
||||
|
||||
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
|
||||
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
|
||||
if len(userEnv) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var envUTF16 []uint16
|
||||
for _, envVar := range userEnv {
|
||||
if envVar != "" {
|
||||
utf16Str, err := windows.UTF16FromString(envVar)
|
||||
if err != nil {
|
||||
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
|
||||
continue
|
||||
}
|
||||
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
}
|
||||
}
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
|
||||
if len(envUTF16) > 0 {
|
||||
return &envUTF16[0], nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// duplicateToPrimaryToken converts an impersonation token to a primary token.
|
||||
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
|
||||
var primaryToken windows.Handle
|
||||
if err := windows.DuplicateTokenEx(
|
||||
windows.Token(token),
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityImpersonation,
|
||||
windows.TokenPrimary,
|
||||
(*windows.Token)(&primaryToken),
|
||||
); err != nil {
|
||||
return 0, fmt.Errorf("duplicate token: %w", err)
|
||||
}
|
||||
return primaryToken, nil
|
||||
}
|
||||
|
||||
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
|
||||
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, process windows.Handle) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
|
||||
|
||||
processErr := waitForProcess(ctx, process)
|
||||
if processErr != nil {
|
||||
return processErr
|
||||
}
|
||||
|
||||
// Clean up in the original order after process completes
|
||||
if err := reader.Close(); err != nil {
|
||||
log.Debugf("close reader: %v", err)
|
||||
}
|
||||
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
if ret == 0 {
|
||||
log.Debugf("close ConPty handle: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := windows.CloseHandle(outputRead); err != nil {
|
||||
log.Debugf("close output read handle: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startIOBridging starts the I/O bridging goroutines.
|
||||
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
|
||||
wg.Add(2)
|
||||
|
||||
// Input: reader (SSH session) -> inputWrite (ConPty)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(inputWrite); err != nil {
|
||||
log.Debugf("close input write handle in goroutine: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
|
||||
log.Debugf("input copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Output: outputRead (ConPty) -> writer (SSH session)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
|
||||
log.Debugf("output copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// waitForProcess waits for process completion with context cancellation.
|
||||
func waitForProcess(ctx context.Context, process windows.Handle) error {
|
||||
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
|
||||
return fmt.Errorf("wait for process %d: %w", process, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildShellArgs builds shell arguments for ConPty execution.
|
||||
func buildShellArgs(shell, command string) []string {
|
||||
if command != "" {
|
||||
return []string{shell, PowerShellCommandFlag, command}
|
||||
}
|
||||
return []string{shell}
|
||||
}
|
||||
|
||||
// buildCommandLine builds a Windows command line from arguments using proper escaping.
|
||||
func buildCommandLine(args []string) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for i, arg := range args {
|
||||
if i > 0 {
|
||||
result.WriteString(" ")
|
||||
}
|
||||
result.WriteString(syscall.EscapeArg(arg))
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// closeHandles closes multiple Windows handles.
|
||||
func closeHandles(handles ...windows.Handle) {
|
||||
for _, handle := range handles {
|
||||
if handle != windows.InvalidHandle {
|
||||
if err := windows.CloseHandle(handle); err != nil {
|
||||
log.Debugf("close handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeProcessInfo closes process and thread handles.
|
||||
func closeProcessInfo(pi *windows.ProcessInformation) {
|
||||
if pi != nil {
|
||||
if err := windows.CloseHandle(pi.Process); err != nil {
|
||||
log.Debugf("close process handle: %v", err)
|
||||
}
|
||||
if err := windows.CloseHandle(pi.Thread); err != nil {
|
||||
log.Debugf("close thread handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// windowsHandleReader wraps a Windows handle for reading.
|
||||
type windowsHandleReader struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
|
||||
var bytesRead uint32
|
||||
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesRead), nil
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Close() error {
|
||||
return windows.CloseHandle(r.handle)
|
||||
}
|
||||
|
||||
// windowsHandleWriter wraps a Windows handle for writing.
|
||||
type windowsHandleWriter struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
|
||||
var bytesWritten uint32
|
||||
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesWritten), nil
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Close() error {
|
||||
return windows.CloseHandle(w.handle)
|
||||
}
|
||||
286
client/ssh/server/winpty/conpty_test.go
Normal file
286
client/ssh/server/winpty/conpty_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestBuildShellArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shell string
|
||||
command string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Shell with command",
|
||||
shell: "powershell.exe",
|
||||
command: "Get-Process",
|
||||
expected: []string{"powershell.exe", "-Command", "Get-Process"},
|
||||
},
|
||||
{
|
||||
name: "CMD with command",
|
||||
shell: "cmd.exe",
|
||||
command: "dir",
|
||||
expected: []string{"cmd.exe", "-Command", "dir"},
|
||||
},
|
||||
{
|
||||
name: "Shell interactive",
|
||||
shell: "powershell.exe",
|
||||
command: "",
|
||||
expected: []string{"powershell.exe"},
|
||||
},
|
||||
{
|
||||
name: "CMD interactive",
|
||||
shell: "cmd.exe",
|
||||
command: "",
|
||||
expected: []string{"cmd.exe"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildShellArgs(tt.shell, tt.command)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple args",
|
||||
args: []string{"cmd.exe", "/c", "echo"},
|
||||
expected: "cmd.exe /c echo",
|
||||
},
|
||||
{
|
||||
name: "Args with spaces",
|
||||
args: []string{"Program Files\\app.exe", "arg with spaces"},
|
||||
expected: `"Program Files\app.exe" "arg with spaces"`,
|
||||
},
|
||||
{
|
||||
name: "Args with quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "hello world"`},
|
||||
expected: `cmd.exe /c "echo \"hello world\""`,
|
||||
},
|
||||
{
|
||||
name: "PowerShell calling PowerShell",
|
||||
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
|
||||
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
|
||||
},
|
||||
{
|
||||
name: "Complex nested quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
|
||||
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
|
||||
},
|
||||
{
|
||||
name: "Path with spaces and args",
|
||||
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
|
||||
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
|
||||
},
|
||||
{
|
||||
name: "Empty argument",
|
||||
args: []string{"cmd.exe", "/c", "echo", ""},
|
||||
expected: `cmd.exe /c echo ""`,
|
||||
},
|
||||
{
|
||||
name: "Argument with backslashes",
|
||||
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
|
||||
expected: `robocopy C:\Source\ C:\Dest\ /E`,
|
||||
},
|
||||
{
|
||||
name: "Empty args",
|
||||
args: []string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single arg with space",
|
||||
args: []string{"path with spaces"},
|
||||
expected: `"path with spaces"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildCommandLine(tt.args)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConPtyPipes(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
|
||||
// Verify all handles are valid
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
|
||||
|
||||
// Clean up handles
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
|
||||
func TestCreateConPty(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
require.NoError(t, err, "Should create ConPty successfully")
|
||||
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
|
||||
|
||||
// Clean up ConPty
|
||||
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
|
||||
}
|
||||
|
||||
func TestConvertEnvironmentToUTF16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userEnv []string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid environment variables",
|
||||
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty environment",
|
||||
userEnv: []string{},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Environment with empty strings",
|
||||
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := convertEnvironmentToUTF16(tt.userEnv)
|
||||
if tt.hasError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if len(tt.userEnv) == 0 {
|
||||
assert.Nil(t, result, "Empty environment should return nil")
|
||||
} else {
|
||||
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateToPrimaryToken(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping token tests in short mode")
|
||||
}
|
||||
|
||||
// Get current process token for testing
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
|
||||
require.NoError(t, err, "Should open current process token")
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
t.Logf("Failed to close token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
|
||||
require.NoError(t, err, "Should duplicate token to primary")
|
||||
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
|
||||
|
||||
// Clean up
|
||||
err = windows.CloseHandle(primaryToken)
|
||||
assert.NoError(t, err, "Should close primary token")
|
||||
}
|
||||
|
||||
func TestWindowsHandleReader(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Write test data
|
||||
testData := []byte("Hello, Windows Handle Reader!")
|
||||
var bytesWritten uint32
|
||||
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
|
||||
require.NoError(t, err, "Should write test data")
|
||||
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
|
||||
|
||||
// Close write handle to signal EOF
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
|
||||
// Test reading
|
||||
reader := &windowsHandleReader{handle: readHandle}
|
||||
buffer := make([]byte, len(testData))
|
||||
n, err := reader.Read(buffer)
|
||||
require.NoError(t, err, "Should read from handle")
|
||||
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read expected data")
|
||||
}
|
||||
|
||||
func TestWindowsHandleWriter(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Test writing
|
||||
testData := []byte("Hello, Windows Handle Writer!")
|
||||
writer := &windowsHandleWriter{handle: writeHandle}
|
||||
n, err := writer.Write(testData)
|
||||
require.NoError(t, err, "Should write to handle")
|
||||
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
|
||||
|
||||
// Close write handle
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was written by reading it back
|
||||
buffer := make([]byte, len(testData))
|
||||
var bytesRead uint32
|
||||
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
|
||||
require.NoError(t, err, "Should read back written data")
|
||||
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read back expected data")
|
||||
}
|
||||
|
||||
// BenchmarkConPtyCreation benchmarks ConPty creation performance
|
||||
func BenchmarkConPtyCreation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
procClosePseudoConsole.Call(uintptr(hPty))
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user