mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
436 lines
13 KiB
Go
436 lines
13 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"os/exec"
|
|
"os/user"
|
|
"path/filepath"
|
|
"strings"
|
|
"unsafe"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/registry"
|
|
|
|
"github.com/netbirdio/netbird/client/ssh/server/winpty"
|
|
)
|
|
|
|
// getUserEnvironment retrieves the Windows environment for the target user.
|
|
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
|
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
|
userToken, err := s.getUserToken(username, domain)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user token: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := windows.CloseHandle(userToken); err != nil {
|
|
log.Debugf("close user token: %v", err)
|
|
}
|
|
}()
|
|
|
|
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
|
}
|
|
|
|
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
|
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
|
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
|
if err != nil {
|
|
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
|
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
|
}
|
|
|
|
envMap := make(map[string]string)
|
|
|
|
if err := s.loadSystemEnvironment(envMap); err != nil {
|
|
log.Debugf("failed to load system environment from registry: %v", err)
|
|
}
|
|
|
|
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
|
|
|
var env []string
|
|
for key, value := range envMap {
|
|
env = append(env, key+"="+value)
|
|
}
|
|
|
|
return env, nil
|
|
}
|
|
|
|
// getUserToken creates a user token for the specified user.
|
|
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
|
privilegeDropper := NewPrivilegeDropper()
|
|
token, err := privilegeDropper.createToken(username, domain)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// loadUserProfile loads the Windows user profile and returns the profile path.
|
|
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
|
|
usernamePtr, err := windows.UTF16PtrFromString(username)
|
|
if err != nil {
|
|
return "", fmt.Errorf("convert username to UTF-16: %w", err)
|
|
}
|
|
|
|
var domainUTF16 *uint16
|
|
if domain != "" && domain != "." {
|
|
domainUTF16, err = windows.UTF16PtrFromString(domain)
|
|
if err != nil {
|
|
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
|
|
}
|
|
}
|
|
|
|
type profileInfo struct {
|
|
dwSize uint32
|
|
dwFlags uint32
|
|
lpUserName *uint16
|
|
lpProfilePath *uint16
|
|
lpDefaultPath *uint16
|
|
lpServerName *uint16
|
|
lpPolicyPath *uint16
|
|
hProfile windows.Handle
|
|
}
|
|
|
|
const PI_NOUI = 0x00000001
|
|
|
|
profile := profileInfo{
|
|
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
|
|
dwFlags: PI_NOUI,
|
|
lpUserName: usernamePtr,
|
|
lpServerName: domainUTF16,
|
|
}
|
|
|
|
userenv := windows.NewLazySystemDLL("userenv.dll")
|
|
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
|
|
|
|
ret, _, err := loadUserProfileW.Call(
|
|
uintptr(userToken),
|
|
uintptr(unsafe.Pointer(&profile)),
|
|
)
|
|
|
|
if ret == 0 {
|
|
return "", fmt.Errorf("LoadUserProfileW: %w", err)
|
|
}
|
|
|
|
if profile.lpProfilePath == nil {
|
|
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
|
|
}
|
|
|
|
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
|
|
return profilePath, nil
|
|
}
|
|
|
|
// loadSystemEnvironment loads system-wide environment variables from registry.
|
|
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
|
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
|
|
registry.QUERY_VALUE)
|
|
if err != nil {
|
|
return fmt.Errorf("open system environment registry key: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := key.Close(); err != nil {
|
|
log.Debugf("close registry key: %v", err)
|
|
}
|
|
}()
|
|
|
|
return s.readRegistryEnvironment(key, envMap)
|
|
}
|
|
|
|
// readRegistryEnvironment reads environment variables from a registry key.
|
|
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
|
|
names, err := key.ReadValueNames(0)
|
|
if err != nil {
|
|
return fmt.Errorf("read registry value names: %w", err)
|
|
}
|
|
|
|
for _, name := range names {
|
|
value, valueType, err := key.GetStringValue(name)
|
|
if err != nil {
|
|
log.Debugf("failed to read registry value %s: %v", name, err)
|
|
continue
|
|
}
|
|
|
|
finalValue := s.expandRegistryValue(value, valueType, name)
|
|
s.setEnvironmentVariable(envMap, name, finalValue)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// expandRegistryValue expands registry values if they contain environment variables.
|
|
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
|
|
if valueType != registry.EXPAND_SZ {
|
|
return value
|
|
}
|
|
|
|
sourcePtr := windows.StringToUTF16Ptr(value)
|
|
expandedBuffer := make([]uint16, 1024)
|
|
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
|
if err != nil {
|
|
log.Debugf("failed to expand environment string for %s: %v", name, err)
|
|
return value
|
|
}
|
|
|
|
// If buffer was too small, retry with larger buffer
|
|
if expandedLen > uint32(len(expandedBuffer)) {
|
|
expandedBuffer = make([]uint16, expandedLen)
|
|
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
|
if err != nil {
|
|
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
|
|
return value
|
|
}
|
|
}
|
|
|
|
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
|
|
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
|
|
}
|
|
return value
|
|
}
|
|
|
|
// setEnvironmentVariable sets an environment variable with special handling for PATH.
|
|
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
|
|
upperName := strings.ToUpper(name)
|
|
|
|
if upperName == "PATH" {
|
|
if existing, exists := envMap["PATH"]; exists && existing != value {
|
|
envMap["PATH"] = existing + ";" + value
|
|
} else {
|
|
envMap["PATH"] = value
|
|
}
|
|
} else {
|
|
envMap[upperName] = value
|
|
}
|
|
}
|
|
|
|
// setUserEnvironmentVariables sets critical user-specific environment variables.
|
|
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
|
|
envMap["USERPROFILE"] = userProfile
|
|
|
|
if len(userProfile) >= 2 && userProfile[1] == ':' {
|
|
envMap["HOMEDRIVE"] = userProfile[:2]
|
|
envMap["HOMEPATH"] = userProfile[2:]
|
|
}
|
|
|
|
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
|
|
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
|
|
|
|
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
|
|
envMap["TEMP"] = tempDir
|
|
envMap["TMP"] = tempDir
|
|
|
|
envMap["USERNAME"] = username
|
|
if domain != "" && domain != "." {
|
|
envMap["USERDOMAIN"] = domain
|
|
envMap["USERDNSDOMAIN"] = domain
|
|
}
|
|
|
|
systemVars := []string{
|
|
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
|
|
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
|
|
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
|
|
}
|
|
|
|
for _, sysVar := range systemVars {
|
|
if sysValue := os.Getenv(sysVar); sysValue != "" {
|
|
envMap[sysVar] = sysValue
|
|
}
|
|
}
|
|
}
|
|
|
|
// prepareCommandEnv prepares environment variables for command execution on Windows
|
|
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
|
username, domain := s.parseUsername(localUser.Username)
|
|
userEnv, err := s.getUserEnvironment(username, domain)
|
|
if err != nil {
|
|
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
|
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
|
env = append(env, prepareSSHEnv(session)...)
|
|
for _, v := range session.Environ() {
|
|
if acceptEnv(v) {
|
|
env = append(env, v)
|
|
}
|
|
}
|
|
return env
|
|
}
|
|
|
|
env := userEnv
|
|
env = append(env, prepareSSHEnv(session)...)
|
|
for _, v := range session.Environ() {
|
|
if acceptEnv(v) {
|
|
env = append(env, v)
|
|
}
|
|
}
|
|
return env
|
|
}
|
|
|
|
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
|
if privilegeResult.User == nil {
|
|
logger.Errorf("no user in privilege result")
|
|
return false
|
|
}
|
|
|
|
cmd := session.Command()
|
|
shell := getUserShell(privilegeResult.User.Uid)
|
|
|
|
if len(cmd) == 0 {
|
|
logger.Infof("starting interactive shell: %s", shell)
|
|
} else {
|
|
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
|
}
|
|
|
|
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
|
return true
|
|
}
|
|
|
|
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
|
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
|
if cmdString == "" {
|
|
return []string{shell, "-NoLogo"}
|
|
}
|
|
return []string{shell, "-Command", cmdString}
|
|
}
|
|
|
|
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
|
logger.Info("starting interactive shell")
|
|
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
|
}
|
|
|
|
// PtyExecutionRequest describes one ConPty execution: what to run, the
// initial terminal size, and the account to run it as.
type PtyExecutionRequest struct {
	// Shell is the shell executable used for the session.
	Shell string
	// Command is the raw command line; empty for an interactive shell.
	Command string
	// Width is the initial terminal width taken from the PTY request.
	Width int
	// Height is the initial terminal height taken from the PTY request.
	Height int
	// Username is the account name without its domain part.
	Username string
	// Domain is the account's domain.
	Domain string
}
|
|
|
|
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
|
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
|
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
|
|
|
privilegeDropper := NewPrivilegeDropper()
|
|
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
|
if err != nil {
|
|
return fmt.Errorf("create user token: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := windows.CloseHandle(userToken); err != nil {
|
|
log.Debugf("close user token: %v", err)
|
|
}
|
|
}()
|
|
|
|
server := &Server{}
|
|
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
|
if err != nil {
|
|
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
|
userEnv = os.Environ()
|
|
}
|
|
|
|
workingDir := getUserHomeFromEnv(userEnv)
|
|
if workingDir == "" {
|
|
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
|
|
}
|
|
|
|
ptyConfig := winpty.PtyConfig{
|
|
Shell: req.Shell,
|
|
Command: req.Command,
|
|
Width: req.Width,
|
|
Height: req.Height,
|
|
WorkingDir: workingDir,
|
|
}
|
|
|
|
userConfig := winpty.UserConfig{
|
|
Token: userToken,
|
|
Environment: userEnv,
|
|
}
|
|
|
|
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
|
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
|
}
|
|
|
|
// getUserHomeFromEnv returns the value of the first non-empty USERPROFILE
// entry in env, or "" when there is none.
//
// env holds "NAME=value" entries. The variable name is matched without regard
// to case, because Windows environment variable names are case-insensitive and
// the entries may come from os.Environ rather than from the upper-cased map
// built in this package. Entries with an empty value are skipped.
func getUserHomeFromEnv(env []string) string {
	const name = "USERPROFILE"

	for _, entry := range env {
		// Need the name, the '=' separator, and at least one value character.
		if len(entry) <= len(name)+1 || entry[len(name)] != '=' {
			continue
		}
		if strings.EqualFold(entry[:len(name)], name) {
			return entry[len(name)+1:]
		}
	}
	return ""
}
|
|
|
|
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
|
// Windows doesn't support process groups in the same way as Unix
|
|
// Process creation groups are handled differently
|
|
}
|
|
|
|
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
|
if cmd.Process == nil {
|
|
return
|
|
}
|
|
|
|
logger := log.WithField("pid", cmd.Process.Pid)
|
|
|
|
if err := cmd.Process.Kill(); err != nil {
|
|
logger.Debugf("kill process failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// detectSuPtySupport always returns false on Windows as su is not available
|
|
func (s *Server) detectSuPtySupport(context.Context) bool {
|
|
return false
|
|
}
|
|
|
|
// detectUtilLinuxLogin always returns false on Windows
|
|
func (s *Server) detectUtilLinuxLogin(context.Context) bool {
|
|
return false
|
|
}
|
|
|
|
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
|
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
|
command := session.RawCommand()
|
|
if command == "" {
|
|
logger.Error("no command specified for PTY execution")
|
|
if err := session.Exit(1); err != nil {
|
|
logSessionExitError(logger, err)
|
|
}
|
|
return false
|
|
}
|
|
|
|
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
|
}
|
|
|
|
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
|
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
|
localUser := privilegeResult.User
|
|
if localUser == nil {
|
|
logger.Errorf("no user in privilege result")
|
|
return false
|
|
}
|
|
|
|
username, domain := s.parseUsername(localUser.Username)
|
|
shell := getUserShell(localUser.Uid)
|
|
|
|
req := PtyExecutionRequest{
|
|
Shell: shell,
|
|
Command: command,
|
|
Width: ptyReq.Window.Width,
|
|
Height: ptyReq.Window.Height,
|
|
Username: username,
|
|
Domain: domain,
|
|
}
|
|
|
|
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
|
logger.Errorf("ConPty execution failed: %v", err)
|
|
if err := session.Exit(1); err != nil {
|
|
logSessionExitError(logger, err)
|
|
}
|
|
return false
|
|
}
|
|
|
|
logger.Debug("ConPty execution completed")
|
|
return true
|
|
}
|