diff --git a/client/ssh/client/client_test.go b/client/ssh/client/client_test.go index d141ac00a..6383a2e32 100644 --- a/client/ssh/client/client_test.go +++ b/client/ssh/client/client_test.go @@ -171,7 +171,12 @@ func TestSSHClient_ContextCancellation(t *testing.T) { currentUser := getCurrentUsername() _, err = DialInsecure(ctx, serverAddr, currentUser) if err != nil { - assert.Contains(t, err.Error(), "context") + // Check for actual timeout-related errors rather than string matching + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + strings.Contains(err.Error(), "timeout"), + "Expected timeout-related error, got: %v", err) } }) @@ -373,8 +378,16 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - currentUser := getCurrentUsername() - client, err := DialInsecure(ctx, serverAddr, currentUser) + // Port forwarding requires the actual current user, not test user + realUser, err := getRealCurrentUser() + require.NoError(t, err) + + // Skip if running as system account that can't do port forwarding + if isSystemAccount(realUser) { + t.Skipf("Skipping port forwarding test - running as system account: %s", realUser) + } + + client, err := DialInsecure(ctx, serverAddr, realUser) require.NoError(t, err) defer func() { if err := client.Close(); err != nil { @@ -634,6 +647,25 @@ func isSystemAccount(username string) bool { return false } +// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding +func getRealCurrentUser() (string, error) { + if runtime.GOOS == "windows" { + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + } + + if username := os.Getenv("USER"); username != "" { + return username, nil + } + + if currentUser, err := user.Current(); err == nil { + return currentUser.Username, nil + } + + return "", fmt.Errorf("unable to determine current user") +} + // isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions func isWindowsPrivilegeError(err error) bool { if err == nil { diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index b0b76f22d..25f4a75eb 100644 --- a/client/ssh/server/command_execution_windows.go +++ b/client/ssh/server/command_execution_windows.go @@ -172,7 +172,18 @@ func (s *Server) expandRegistryValue(value string, valueType uint32, name string log.Debugf("failed to expand environment string for %s: %v", name, err) return value } - if expandedLen > 0 { + + // If buffer was too small, retry with larger buffer + if expandedLen > uint32(len(expandedBuffer)) { + expandedBuffer = make([]uint16, expandedLen) + expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer))) + if err != nil { + log.Debugf("failed to expand environment string for %s on retry: %v", name, err) + return value + } + } + + if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) { return windows.UTF16ToString(expandedBuffer[:expandedLen-1]) } return value diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 7cc924215..eb5e5c519 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -873,4 +873,3 @@ func createWindowsTestUser(t *testing.T, username string) bool { t.Logf("Test user creation result: %s", string(output)) return true } -