mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
430
client/ssh/server/command_execution_windows.go
Normal file
430
client/ssh/server/command_execution_windows.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/server/winpty"
|
||||
)
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
||||
}
|
||||
|
||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
|
||||
var env []string
|
||||
for key, value := range envMap {
|
||||
env = append(env, key+"="+value)
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// loadUserProfile loads the Windows user profile and returns the profile path.
|
||||
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
|
||||
usernamePtr, err := windows.UTF16PtrFromString(username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert username to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
var domainUTF16 *uint16
|
||||
if domain != "" && domain != "." {
|
||||
domainUTF16, err = windows.UTF16PtrFromString(domain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type profileInfo struct {
|
||||
dwSize uint32
|
||||
dwFlags uint32
|
||||
lpUserName *uint16
|
||||
lpProfilePath *uint16
|
||||
lpDefaultPath *uint16
|
||||
lpServerName *uint16
|
||||
lpPolicyPath *uint16
|
||||
hProfile windows.Handle
|
||||
}
|
||||
|
||||
const PI_NOUI = 0x00000001
|
||||
|
||||
profile := profileInfo{
|
||||
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
|
||||
dwFlags: PI_NOUI,
|
||||
lpUserName: usernamePtr,
|
||||
lpServerName: domainUTF16,
|
||||
}
|
||||
|
||||
userenv := windows.NewLazySystemDLL("userenv.dll")
|
||||
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
|
||||
|
||||
ret, _, err := loadUserProfileW.Call(
|
||||
uintptr(userToken),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
)
|
||||
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("LoadUserProfileW: %w", err)
|
||||
}
|
||||
|
||||
if profile.lpProfilePath == nil {
|
||||
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
|
||||
}
|
||||
|
||||
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
|
||||
return profilePath, nil
|
||||
}
|
||||
|
||||
// loadSystemEnvironment loads system-wide environment variables from registry.
|
||||
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
|
||||
registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open system environment registry key: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := key.Close(); err != nil {
|
||||
log.Debugf("close registry key: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.readRegistryEnvironment(key, envMap)
|
||||
}
|
||||
|
||||
// readRegistryEnvironment reads environment variables from a registry key.
|
||||
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
|
||||
names, err := key.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read registry value names: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
value, valueType, err := key.GetStringValue(name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to read registry value %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
finalValue := s.expandRegistryValue(value, valueType, name)
|
||||
s.setEnvironmentVariable(envMap, name, finalValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandRegistryValue expands registry values if they contain environment variables.
|
||||
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
|
||||
if valueType != registry.EXPAND_SZ {
|
||||
return value
|
||||
}
|
||||
|
||||
sourcePtr := windows.StringToUTF16Ptr(value)
|
||||
expandedBuffer := make([]uint16, 1024)
|
||||
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s: %v", name, err)
|
||||
return value
|
||||
}
|
||||
|
||||
// If buffer was too small, retry with larger buffer
|
||||
if expandedLen > uint32(len(expandedBuffer)) {
|
||||
expandedBuffer = make([]uint16, expandedLen)
|
||||
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
|
||||
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// setEnvironmentVariable sets an environment variable with special handling for PATH.
|
||||
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
|
||||
upperName := strings.ToUpper(name)
|
||||
|
||||
if upperName == "PATH" {
|
||||
if existing, exists := envMap["PATH"]; exists && existing != value {
|
||||
envMap["PATH"] = existing + ";" + value
|
||||
} else {
|
||||
envMap["PATH"] = value
|
||||
}
|
||||
} else {
|
||||
envMap[upperName] = value
|
||||
}
|
||||
}
|
||||
|
||||
// setUserEnvironmentVariables sets critical user-specific environment variables.
|
||||
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
|
||||
envMap["USERPROFILE"] = userProfile
|
||||
|
||||
if len(userProfile) >= 2 && userProfile[1] == ':' {
|
||||
envMap["HOMEDRIVE"] = userProfile[:2]
|
||||
envMap["HOMEPATH"] = userProfile[2:]
|
||||
}
|
||||
|
||||
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
|
||||
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
|
||||
|
||||
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
|
||||
envMap["TEMP"] = tempDir
|
||||
envMap["TMP"] = tempDir
|
||||
|
||||
envMap["USERNAME"] = username
|
||||
if domain != "" && domain != "." {
|
||||
envMap["USERDOMAIN"] = domain
|
||||
envMap["USERDNSDOMAIN"] = domain
|
||||
}
|
||||
|
||||
systemVars := []string{
|
||||
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
|
||||
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
|
||||
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
|
||||
}
|
||||
|
||||
for _, sysVar := range systemVars {
|
||||
if sysValue := os.Getenv(sysVar); sysValue != "" {
|
||||
envMap[sysVar] = sysValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
if privilegeResult.User == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
return true
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-NoLogo"}
|
||||
}
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
logger.Info("starting interactive shell")
|
||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
Username string
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
workingDir := getUserHomeFromEnv(userEnv)
|
||||
if workingDir == "" {
|
||||
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
|
||||
}
|
||||
|
||||
ptyConfig := winpty.PtyConfig{
|
||||
Shell: req.Shell,
|
||||
Command: req.Command,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
WorkingDir: workingDir,
|
||||
}
|
||||
|
||||
userConfig := winpty.UserConfig{
|
||||
Token: userToken,
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
for _, envVar := range env {
|
||||
if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
|
||||
return envVar[12:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
// Windows doesn't support process groups in the same way as Unix
|
||||
// Process creation groups are handled differently
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// detectSuPtySupport always returns false on Windows as su is not available
|
||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
logger.Error("no command specified for PTY execution")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
||||
}
|
||||
|
||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
|
||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
||||
logger.Errorf("ConPty execution failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ConPty execution completed")
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user