[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -0,0 +1,206 @@
package server
import (
"errors"
"fmt"
"io"
"os"
"os/exec"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// handleCommand executes an SSH command with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
hasPty := winCh != nil
commandType := "command"
if hasPty {
commandType = "Pty command"
}
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
if hasPty {
errorMsg += " with Pty"
}
errorMsg += "\n"
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return
}
if !hasPty {
if s.executeCommand(logger, session, execCmd, cleanup) {
logger.Debugf("%s execution completed", commandType)
}
return
}
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
}
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, func() {}, nil
}
// executeCommand executes the command and handles I/O and exit codes
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
defer cleanup()
s.setupProcessGroup(execCmd)
stdinPipe, err := execCmd.StdinPipe()
if err != nil {
logger.Errorf("create stdin pipe: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
execCmd.Stdout = session
execCmd.Stderr = session.Stderr()
if execCmd.Dir != "" {
if _, err := os.Stat(execCmd.Dir); err != nil {
logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
execCmd.Dir = "/"
}
}
if err := execCmd.Start(); err != nil {
logger.Errorf("command start failed: %v", err)
// no user message for exec failure, just exit
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
go s.handleCommandIO(logger, stdinPipe, session)
return s.waitForCommandCleanup(logger, session, execCmd)
}
// handleCommandIO manages stdin/stdout copying in a goroutine
func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
defer func() {
if err := stdinPipe.Close(); err != nil {
logger.Debugf("stdin pipe close error: %v", err)
}
}()
if _, err := io.Copy(stdinPipe, session); err != nil {
logger.Debugf("stdin copy error: %v", err)
}
}
// waitForCommandCleanup waits for command completion with session disconnect handling
func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
ctx := session.Context()
done := make(chan error, 1)
go func() {
done <- execCmd.Wait()
}()
select {
case <-ctx.Done():
logger.Debugf("session cancelled, terminating command")
s.killProcessGroup(execCmd)
select {
case err := <-done:
logger.Tracef("command terminated after session cancellation: %v", err)
case <-time.After(5 * time.Second):
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
}
if err := session.Exit(130); err != nil {
logSessionExitError(logger, err)
}
return false
case err := <-done:
return s.handleCommandCompletion(logger, session, err)
}
}
// handleCommandCompletion handles command completion
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
if err != nil {
logger.Debugf("command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
return false
}
s.handleSessionExit(session, nil, logger)
return true
}
// handleSessionExit handles command errors and sets appropriate exit codes
func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
if err == nil {
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
return
}
var exitError *exec.ExitError
if errors.As(err, &exitError) {
if err := session.Exit(exitError.ExitCode()); err != nil {
logSessionExitError(logger, err)
}
} else {
logger.Debugf("non-exit error in command execution: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
}
}

View File

@@ -0,0 +1,52 @@
//go:build js
package server
import (
"context"
"errors"
"os/exec"
"os/user"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
return nil
}
// setupProcessGroup is not supported on JS/WASM
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
}
// killProcessGroup is not supported on JS/WASM
func (s *Server) killProcessGroup(*exec.Cmd) {
}
// detectSuPtySupport always returns false on JS/WASM
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}

View File

@@ -0,0 +1,329 @@
//go:build unix
package server
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strings"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// ptyManager manages Pty file operations with thread safety
type ptyManager struct {
file *os.File
mu sync.RWMutex
closed bool
closeErr error
once sync.Once
}
func newPtyManager(file *os.File) *ptyManager {
return &ptyManager{file: file}
}
func (pm *ptyManager) Close() error {
pm.once.Do(func() {
pm.mu.Lock()
pm.closed = true
pm.closeErr = pm.file.Close()
pm.mu.Unlock()
})
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.closeErr
}
func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
pm.mu.RLock()
defer pm.mu.RUnlock()
if pm.closed {
return errors.New("pty is closed")
}
return pty.Setsize(pm.file, ws)
}
func (pm *ptyManager) File() *os.File {
return pm.file
}
// detectSuPtySupport checks if su supports the --pty flag
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "su", "--help")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("su --help failed (may not support --help): %v", err)
return false
}
supported := strings.Contains(string(output), "--pty")
log.Debugf("su --pty support detected: %v", supported)
return supported
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username, "-c", command)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
}
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
// executeCommandWithPty executes a command with PTY allocation
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
termType := ptyReq.Term
if termType == "" {
termType = "xterm-256color"
}
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil {
logger.Errorf("Pty command creation failed: %v", err)
errorMsg := "User switching failed - login command not available\r\n"
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
logger.Infof("starting interactive shell: %s", execCmd.Path)
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
if err != nil {
logger.Errorf("Pty start failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
ptyMgr := newPtyManager(ptmx)
defer func() {
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close error: %v", err)
}
}()
go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
s.handlePtyIO(logger, session, ptyMgr)
s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
return true
}
func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
winSize := &pty.Winsize{
Cols: uint16(ptyReq.Window.Width),
Rows: uint16(ptyReq.Window.Height),
}
if winSize.Cols == 0 {
winSize.Cols = 80
}
if winSize.Rows == 0 {
winSize.Rows = 24
}
ptmx, err := pty.StartWithSize(execCmd, winSize)
if err != nil {
return nil, fmt.Errorf("start Pty: %w", err)
}
return ptmx, nil
}
func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
for {
select {
case <-session.Context().Done():
return
case win, ok := <-winCh:
if !ok {
return
}
if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
}
}
}
}
func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
ptmx := ptyMgr.File()
go func() {
if _, err := io.Copy(ptmx, session); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty input copy error: %v", err)
}
}
}()
go func() {
defer func() {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Debugf("session close error: %v", err)
}
}()
if _, err := io.Copy(session, ptmx); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty output copy error: %v", err)
}
}
}()
}
func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
ctx := session.Context()
done := make(chan error, 1)
go func() {
done <- execCmd.Wait()
}()
select {
case <-ctx.Done():
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
case err := <-done:
s.handlePtyCommandCompletion(logger, session, err)
}
}
func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
logger.Debugf("Pty session cancelled, terminating command")
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close during session cancellation: %v", err)
}
s.killProcessGroup(execCmd)
select {
case err := <-done:
if err != nil {
logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
} else {
logger.Debugf("Pty command terminated after session cancellation")
}
case <-time.After(5 * time.Second):
logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
}
if err := session.Exit(130); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
if err != nil {
logger.Debugf("Pty command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
return
}
// Normal completion
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
}
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
if cmd.Process == nil {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
pgid := cmd.Process.Pid
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
logger.Debugf("kill process group SIGTERM: %v", err)
return
}
const gracePeriod = 500 * time.Millisecond
const checkInterval = 50 * time.Millisecond
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
timeout := time.After(gracePeriod)
for {
select {
case <-timeout:
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
logger.Debugf("kill process group SIGKILL: %v", err)
}
return
case <-ticker.C:
if err := syscall.Kill(-pgid, 0); err != nil {
return
}
}
}
}

View File

@@ -0,0 +1,430 @@
package server
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"strings"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/ssh/server/winpty"
)
// getUserEnvironment retrieves the Windows environment for the target user.
// Follows OpenSSH's resilient approach with graceful degradation on failures.
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
userToken, err := s.getUserToken(username, domain)
if err != nil {
return nil, fmt.Errorf("get user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err)
}
}()
return s.getUserEnvironmentWithToken(userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil {
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
}
envMap := make(map[string]string)
if err := s.loadSystemEnvironment(envMap); err != nil {
log.Debugf("failed to load system environment from registry: %v", err)
}
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
var env []string
for key, value := range envMap {
env = append(env, key+"="+value)
}
return env, nil
}
// getUserToken creates a user token for the specified user.
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper()
token, err := privilegeDropper.createToken(username, domain)
if err != nil {
return 0, fmt.Errorf("generate S4U user token: %w", err)
}
return token, nil
}
// loadUserProfile loads the Windows user profile and returns the profile path.
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
usernamePtr, err := windows.UTF16PtrFromString(username)
if err != nil {
return "", fmt.Errorf("convert username to UTF-16: %w", err)
}
var domainUTF16 *uint16
if domain != "" && domain != "." {
domainUTF16, err = windows.UTF16PtrFromString(domain)
if err != nil {
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
}
}
type profileInfo struct {
dwSize uint32
dwFlags uint32
lpUserName *uint16
lpProfilePath *uint16
lpDefaultPath *uint16
lpServerName *uint16
lpPolicyPath *uint16
hProfile windows.Handle
}
const PI_NOUI = 0x00000001
profile := profileInfo{
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
dwFlags: PI_NOUI,
lpUserName: usernamePtr,
lpServerName: domainUTF16,
}
userenv := windows.NewLazySystemDLL("userenv.dll")
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
ret, _, err := loadUserProfileW.Call(
uintptr(userToken),
uintptr(unsafe.Pointer(&profile)),
)
if ret == 0 {
return "", fmt.Errorf("LoadUserProfileW: %w", err)
}
if profile.lpProfilePath == nil {
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
}
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
return profilePath, nil
}
// loadSystemEnvironment loads system-wide environment variables from registry.
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
registry.QUERY_VALUE)
if err != nil {
return fmt.Errorf("open system environment registry key: %w", err)
}
defer func() {
if err := key.Close(); err != nil {
log.Debugf("close registry key: %v", err)
}
}()
return s.readRegistryEnvironment(key, envMap)
}
// readRegistryEnvironment reads environment variables from a registry key.
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
names, err := key.ReadValueNames(0)
if err != nil {
return fmt.Errorf("read registry value names: %w", err)
}
for _, name := range names {
value, valueType, err := key.GetStringValue(name)
if err != nil {
log.Debugf("failed to read registry value %s: %v", name, err)
continue
}
finalValue := s.expandRegistryValue(value, valueType, name)
s.setEnvironmentVariable(envMap, name, finalValue)
}
return nil
}
// expandRegistryValue expands registry values if they contain environment variables.
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
if valueType != registry.EXPAND_SZ {
return value
}
sourcePtr := windows.StringToUTF16Ptr(value)
expandedBuffer := make([]uint16, 1024)
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
if err != nil {
log.Debugf("failed to expand environment string for %s: %v", name, err)
return value
}
// If buffer was too small, retry with larger buffer
if expandedLen > uint32(len(expandedBuffer)) {
expandedBuffer = make([]uint16, expandedLen)
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
if err != nil {
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
return value
}
}
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
}
return value
}
// setEnvironmentVariable sets an environment variable with special handling for PATH.
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
upperName := strings.ToUpper(name)
if upperName == "PATH" {
if existing, exists := envMap["PATH"]; exists && existing != value {
envMap["PATH"] = existing + ";" + value
} else {
envMap["PATH"] = value
}
} else {
envMap[upperName] = value
}
}
// setUserEnvironmentVariables sets critical user-specific environment variables.
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
envMap["USERPROFILE"] = userProfile
if len(userProfile) >= 2 && userProfile[1] == ':' {
envMap["HOMEDRIVE"] = userProfile[:2]
envMap["HOMEPATH"] = userProfile[2:]
}
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
envMap["TEMP"] = tempDir
envMap["TMP"] = tempDir
envMap["USERNAME"] = username
if domain != "" && domain != "." {
envMap["USERDOMAIN"] = domain
envMap["USERDNSDOMAIN"] = domain
}
systemVars := []string{
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
}
for _, sysVar := range systemVars {
if sysValue := os.Getenv(sysVar); sysValue != "" {
envMap[sysVar] = sysValue
}
}
}
// prepareCommandEnv prepares environment variables for command execution on Windows
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
username, domain := s.parseUsername(localUser.Username)
userEnv, err := s.getUserEnvironment(username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
env := userEnv
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid)
if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-NoLogo"}
}
return []string{shell, "-Command", cmdString}
}
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
}
type PtyExecutionRequest struct {
Shell string
Command string
Width int
Height int
Username string
Domain string
}
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
privilegeDropper := NewPrivilegeDropper()
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
if err != nil {
return fmt.Errorf("create user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err)
}
}()
server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ()
}
workingDir := getUserHomeFromEnv(userEnv)
if workingDir == "" {
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
}
ptyConfig := winpty.PtyConfig{
Shell: req.Shell,
Command: req.Command,
Width: req.Width,
Height: req.Height,
WorkingDir: workingDir,
}
userConfig := winpty.UserConfig{
Token: userToken,
Environment: userEnv,
}
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
}
func getUserHomeFromEnv(env []string) string {
for _, envVar := range env {
if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
return envVar[12:]
}
}
return ""
}
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
// Windows doesn't support process groups in the same way as Unix
// Process creation groups are handled differently
}
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
if cmd.Process == nil {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
if err := cmd.Process.Kill(); err != nil {
logger.Debugf("kill process failed: %v", err)
}
}
// detectSuPtySupport always returns false on Windows as su is not available
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
return false
}
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
logger.Debug("ConPty execution completed")
return true
}

View File

@@ -0,0 +1,722 @@
package server
import (
"context"
"crypto/ed25519"
"crypto/rand"
"fmt"
"io"
"net"
"os"
"os/exec"
"runtime"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers()
os.Exit(code)
}
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
func TestSSHServerCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH compatibility tests in short mode")
}
// Check if ssh binary is available
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Set up SSH server - use our existing key generation for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate OpenSSH-compatible keys for client
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Create temporary key files for SSH client
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
defer cleanupKey()
// Extract host and port from server address
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
// Get appropriate user for SSH connection (handle system accounts)
username := testutil.GetTestUsername(t)
t.Run("basic command execution", func(t *testing.T) {
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
})
t.Run("interactive command", func(t *testing.T) {
testSSHInteractiveCommand(t, host, portStr, clientKeyFile)
})
t.Run("port forwarding", func(t *testing.T) {
testSSHPortForwarding(t, host, portStr, clientKeyFile)
})
}
// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user.
func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) {
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"echo", "hello_world")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("SSH command failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
}
// testSSHInteractiveCommand tests interactive shell session.
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host))
stdin, err := cmd.StdinPipe()
if err != nil {
t.Skipf("Cannot create stdin pipe: %v", err)
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
t.Skipf("Cannot create stdout pipe: %v", err)
return
}
err = cmd.Start()
if err != nil {
t.Logf("Cannot start SSH session: %v", err)
return
}
go func() {
defer func() {
if err := stdin.Close(); err != nil {
t.Logf("stdin close error: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil {
t.Logf("stdin write error: %v", err)
}
time.Sleep(100 * time.Millisecond)
if _, err := stdin.Write([]byte("exit\n")); err != nil {
t.Logf("stdin write error: %v", err)
}
}()
output, err := io.ReadAll(stdout)
if err != nil {
t.Logf("Cannot read SSH output: %v", err)
}
err = cmd.Wait()
if err != nil {
t.Logf("SSH interactive session error: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work")
}
// testSSHPortForwarding tests port forwarding compatibility.
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer testServer.Close()
testServerAddr := testServer.Addr().String()
expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK"
go func() {
for {
conn, err := testServer.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer func() {
if err := c.Close(); err != nil {
t.Logf("test server connection close error: %v", err)
}
}()
buf := make([]byte, 1024)
if _, err := c.Read(buf); err != nil {
t.Logf("Test server read error: %v", err)
}
if _, err := c.Write([]byte(expectedResponse)); err != nil {
t.Logf("Test server write error: %v", err)
}
}(conn)
}
}()
localListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
localAddr := localListener.Addr().String()
localListener.Close()
_, localPort, err := net.SplitHostPort(localAddr)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr)
cmd := exec.CommandContext(ctx, "ssh",
"-i", keyFile,
"-p", port,
"-L", forwardSpec,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-N",
fmt.Sprintf("%s@%s", username, host))
err = cmd.Start()
if err != nil {
t.Logf("Cannot start SSH port forwarding: %v", err)
return
}
defer func() {
if cmd.Process != nil {
if err := cmd.Process.Kill(); err != nil {
t.Logf("process kill error: %v", err)
}
}
if err := cmd.Wait(); err != nil {
t.Logf("process wait after kill: %v", err)
}
}()
time.Sleep(500 * time.Millisecond)
conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
if err != nil {
t.Logf("Cannot connect to forwarded port: %v", err)
return
}
defer func() {
if err := conn.Close(); err != nil {
t.Logf("forwarded connection close error: %v", err)
}
}()
request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
_, err = conn.Write([]byte(request))
require.NoError(t, err)
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
log.Debugf("failed to set read deadline: %v", err)
}
response := make([]byte, len(expectedResponse))
n, err := io.ReadFull(conn, response)
if err != nil {
t.Logf("Cannot read forwarded response: %v", err)
return
}
assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes")
assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding")
}
// isSSHClientAvailable checks if the ssh binary is available
func isSSHClientAvailable() bool {
_, err := exec.LookPath("ssh")
return err == nil
}
// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use.
func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) {
// Check if ssh-keygen is available
if _, err := exec.LookPath("ssh-keygen"); err != nil {
// Fall back to our existing key generation and try to convert
return generateOpenSSHKeyFallback()
}
// Create temporary file for ssh-keygen
tempFile, err := os.CreateTemp("", "ssh_keygen_*")
if err != nil {
return nil, nil, fmt.Errorf("create temp file: %w", err)
}
keyPath := tempFile.Name()
tempFile.Close()
// Remove the temp file so ssh-keygen can create it
if err := os.Remove(keyPath); err != nil {
t.Logf("failed to remove key file: %v", err)
}
// Clean up temp files
defer func() {
if err := os.Remove(keyPath); err != nil {
t.Logf("failed to cleanup key file: %v", err)
}
if err := os.Remove(keyPath + ".pub"); err != nil {
t.Logf("failed to cleanup public key file: %v", err)
}
}()
// Generate key using ssh-keygen
cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output))
}
// Read private key
privKeyBytes, err := os.ReadFile(keyPath)
if err != nil {
return nil, nil, fmt.Errorf("read private key: %w", err)
}
// Read public key
pubKeyBytes, err := os.ReadFile(keyPath + ".pub")
if err != nil {
return nil, nil, fmt.Errorf("read public key: %w", err)
}
return privKeyBytes, pubKeyBytes, nil
}
// generateOpenSSHKeyFallback falls back to generating keys using our existing method
func generateOpenSSHKeyFallback() ([]byte, []byte, error) {
// Generate shared.ED25519 key pair using our existing method
_, privKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("generate key: %w", err)
}
// Convert to SSH format
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
if err != nil {
return nil, nil, fmt.Errorf("create signer: %w", err)
}
// For the fallback, just use our PKCS#8 format and hope it works
// This won't be in OpenSSH format but might still work with some SSH clients
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, nil, fmt.Errorf("generate fallback key: %w", err)
}
// Get public key in SSH format
sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey())
return hostKey, sshPubKey, nil
}
// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes
func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) {
t.Helper()
tempFile, err := os.CreateTemp("", "ssh_test_key_*")
require.NoError(t, err)
_, err = tempFile.Write(keyBytes)
require.NoError(t, err)
err = tempFile.Close()
require.NoError(t, err)
// Set proper permissions for SSH key (readable by owner only)
err = os.Chmod(tempFile.Name(), 0600)
require.NoError(t, err)
cleanup := func() {
_ = os.Remove(tempFile.Name())
}
return tempFile.Name(), cleanup
}
// createTempKeyFile creates a temporary SSH private key file (for backward compatibility)
func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
return createTempKeyFileFromBytes(t, privateKey)
}
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
func TestSSHServerFeatureCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH feature compatibility tests in short mode")
}
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Test various SSH features
testCases := []struct {
name string
testFunc func(t *testing.T, host, port, keyFile string)
description string
}{
{
name: "command_with_flags",
testFunc: testCommandWithFlags,
description: "Commands with flags should work like standard SSH",
},
{
name: "environment_variables",
testFunc: testEnvironmentVariables,
description: "Environment variables should be available",
},
{
name: "exit_codes",
testFunc: testExitCodes,
description: "Exit codes should be properly handled",
},
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.testFunc(t, host, portStr, clientKeyFile)
})
}
}
// testCommandWithFlags tests that commands with flags work properly
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Test ls with flags
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"ls", "-la", "/tmp")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Command with flags failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
// Should not be empty and should not contain error messages
assert.NotEmpty(t, string(output), "ls -la should produce output")
assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed")
}
// testEnvironmentVariables tests that environment is properly set up
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"echo", "$HOME")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Environment test failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
// HOME environment variable should be available
homeOutput := strings.TrimSpace(string(output))
assert.NotEmpty(t, homeOutput, "HOME environment variable should be set")
assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded")
}
// testExitCodes tests that exit codes are properly handled
func testExitCodes(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Test successful command (exit code 0)
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"true") // always succeeds
err := cmd.Run()
assert.NoError(t, err, "Command with exit code 0 should succeed")
// Test failing command (exit code 1)
cmd = exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"false") // always fails
err = cmd.Run()
assert.Error(t, err, "Command with exit code 1 should fail")
// Check if it's the right kind of error
if exitError, ok := err.(*exec.ExitError); ok {
assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved")
}
}
// TestSSHServerSecurityFeatures tests security-related SSH features
func TestSSHServerSecurityFeatures(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH security tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Set up SSH server with specific security settings
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
t.Run("key_authentication", func(t *testing.T) {
// Test that key authentication works
cmd := exec.Command("ssh",
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "PasswordAuthentication=no",
fmt.Sprintf("%s@%s", username, host),
"echo", "auth_success")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Key authentication failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "auth_success", "Key authentication should work")
})
t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) {
// Create a different key that shouldn't be accepted
wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey)
defer cleanupWrongKey()
// Test that wrong key is rejected
cmd := exec.Command("ssh",
"-i", wrongKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "PasswordAuthentication=no",
fmt.Sprintf("%s@%s", username, host),
"echo", "should_not_work")
err = cmd.Run()
assert.NoError(t, err, "Any key should work in no-auth mode")
})
}
// TestCrossPlatformCompatibility tests cross-platform behavior
func TestCrossPlatformCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("Skipping cross-platform compatibility tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
// Test platform-specific commands
var testCommand string
switch runtime.GOOS {
case "windows":
testCommand = "echo %OS%"
default:
testCommand = "uname"
}
cmd := exec.Command("ssh",
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
testCommand)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Platform-specific command failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
outputStr := strings.TrimSpace(string(output))
t.Logf("Platform command output: %s", outputStr)
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
}

View File

@@ -0,0 +1,253 @@
//go:build unix
package server
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"runtime"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
)
// Exit codes for executor process communication
const (
ExitCodeSuccess = 0
ExitCodePrivilegeDropFail = 10
ExitCodeShellExecFail = 11
ExitCodeValidationFail = 12
)
// ExecutorConfig holds configuration for the executor process
type ExecutorConfig struct {
UID uint32
GID uint32
Groups []uint32
WorkingDir string
Shell string
Command string
PTY bool
}
// PrivilegeDropper handles secure privilege dropping in child processes
type PrivilegeDropper struct{}
// NewPrivilegeDropper creates a new privilege dropper
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) {
netbirdPath, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("get netbird executable path: %w", err)
}
if err := pd.validatePrivileges(config.UID, config.GID); err != nil {
return nil, fmt.Errorf("invalid privileges: %w", err)
}
args := []string{
"ssh", "exec",
"--uid", fmt.Sprintf("%d", config.UID),
"--gid", fmt.Sprintf("%d", config.GID),
"--working-dir", config.WorkingDir,
"--shell", config.Shell,
}
for _, group := range config.Groups {
args = append(args, "--groups", fmt.Sprintf("%d", group))
}
if config.PTY {
args = append(args, "--pty")
}
if config.Command != "" {
args = append(args, "--cmd", config.Command)
}
// Log executor args safely - show all args except hide the command value
safeArgs := make([]string, len(args))
copy(safeArgs, args)
for i := 0; i < len(safeArgs)-1; i++ {
if safeArgs[i] == "--cmd" {
cmdParts := strings.Fields(safeArgs[i+1])
safeArgs[i+1] = safeLogCommand(cmdParts)
break
}
}
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
return exec.CommandContext(ctx, netbirdPath, args...), nil
}
// DropPrivileges performs privilege dropping with thread locking for security
func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
if err := pd.validatePrivileges(targetUID, targetGID); err != nil {
return fmt.Errorf("invalid privileges: %w", err)
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
originalUID := os.Geteuid()
originalGID := os.Getegid()
if originalUID != int(targetUID) || originalGID != int(targetGID) {
if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil {
return fmt.Errorf("set groups and IDs: %w", err)
}
}
if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil {
return err
}
log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID)
return nil
}
// setGroupsAndIDs sets the supplementary groups, GID, and UID
func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
groups := make([]int, len(supplementaryGroups))
for i, g := range supplementaryGroups {
groups[i] = int(g)
}
if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" {
if len(groups) == 0 || groups[0] != int(targetGID) {
groups = append([]int{int(targetGID)}, groups...)
}
}
if err := syscall.Setgroups(groups); err != nil {
return fmt.Errorf("setgroups to %v: %w", groups, err)
}
if err := syscall.Setgid(int(targetGID)); err != nil {
return fmt.Errorf("setgid to %d: %w", targetGID, err)
}
if err := syscall.Setuid(int(targetUID)); err != nil {
return fmt.Errorf("setuid to %d: %w", targetUID, err)
}
return nil
}
// validatePrivilegeDropSuccess validates that privilege dropping was successful
func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error {
if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil {
return err
}
if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil {
return err
}
return nil
}
// validatePrivilegeDropReversibility ensures privileges cannot be restored
func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error {
if originalGID != int(targetGID) {
if err := syscall.Setegid(originalGID); err == nil {
return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID)
}
}
if originalUID != int(targetUID) {
if err := syscall.Seteuid(originalUID); err == nil {
return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID)
}
}
return nil
}
// validateCurrentPrivileges validates the current UID and GID match the target
func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error {
currentUID := os.Geteuid()
if currentUID != int(targetUID) {
return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID)
}
currentGID := os.Getegid()
if currentGID != int(targetGID) {
return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID)
}
return nil
}
// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures
func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) {
log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
// TODO: Implement Pty support for executor path
if config.PTY {
config.PTY = false
}
if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err)
os.Exit(ExitCodePrivilegeDropFail)
}
if config.WorkingDir != "" {
if err := os.Chdir(config.WorkingDir); err != nil {
log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err)
}
}
var execCmd *exec.Cmd
if config.Command == "" {
os.Exit(ExitCodeSuccess)
}
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
execCmd.Stdin = os.Stdin
execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
if err := execCmd.Run(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// Normal command exit with non-zero code - not an SSH execution error
log.Tracef("command exited with code %d", exitError.ExitCode())
os.Exit(exitError.ExitCode())
}
// Actual execution failure (command not found, permission denied, etc.)
log.Debugf("command execution failed: %v", err)
os.Exit(ExitCodeShellExecFail)
}
os.Exit(ExitCodeSuccess)
}
// validatePrivileges validates that privilege dropping to the target UID/GID is allowed
func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error {
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
// Allow same-user operations (no privilege dropping needed)
if uid == currentUID && gid == currentGID {
return nil
}
// Only root can drop privileges to other users
if currentUID != 0 {
return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid)
}
// Root can drop to any user (including root itself)
return nil
}

View File

@@ -0,0 +1,262 @@
//go:build unix
package server
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
pd := NewPrivilegeDropper()
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
tests := []struct {
name string
uid uint32
gid uint32
wantErr bool
}{
{
name: "same user - no privilege drop needed",
uid: currentUID,
gid: currentGID,
wantErr: false,
},
{
name: "non-root to different user should fail",
uid: currentUID + 1, // Use a different UID to ensure it's actually different
gid: currentGID + 1, // Use a different GID to ensure it's actually different
wantErr: currentUID != 0, // Only fail if current user is not root
},
{
name: "root can drop to any user",
uid: 1000,
gid: 1000,
wantErr: false,
},
{
name: "root can stay as root",
uid: 0,
gid: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Skip non-root tests when running as root, and root tests when not root
if tt.name == "non-root to different user should fail" && currentUID == 0 {
t.Skip("Skipping non-root test when running as root")
}
if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 {
t.Skip("Skipping root test when not running as root")
}
err := pd.validatePrivileges(tt.uid, tt.gid)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
// This test requires root privileges and will be skipped if not running as root
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("This test requires root privileges")
}
// Find a non-root user to test with
testUser, err := findNonRootUser()
if err != nil {
t.Skip("No suitable non-root user found for testing")
}
// Verify the user actually exists by looking it up again
_, err = user.LookupId(testUser.Uid)
if err != nil {
t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err)
}
uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
require.NoError(t, err)
targetUID := uint32(uid64)
gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
require.NoError(t, err)
targetGID := uint32(gid64)
// Test in a child process to avoid affecting the test runner
if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
pd := NewPrivilegeDropper()
// This should succeed
err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
require.NoError(t, err)
// Verify we are now running as the target user
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
assert.Equal(t, targetUID, currentUID, "UID should match target")
assert.Equal(t, targetGID, currentGID, "GID should match target")
assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
return
}
// Fork a child process to test privilege dropping
cmd := os.Args[0]
args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
execCmd := exec.Command(cmd, args...)
execCmd.Env = env
err = execCmd.Run()
require.NoError(t, err, "Child process should succeed")
}
// findNonRootUser finds any non-root user on the system for testing
func findNonRootUser() (*user.User, error) {
// Try common non-root users, but avoid "nobody" on macOS due to negative UID issues
commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
for _, username := range commonUsers {
if u, err := user.Lookup(username); err == nil {
// Parse as signed integer first to handle negative UIDs
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
if err != nil {
continue
}
// Skip negative UIDs (like nobody=-2 on macOS) and root
if uid64 > 0 && uid64 != 0 {
return u, nil
}
}
}
// If no common users found, try to find any regular user with UID > 100
// This helps on macOS where regular users start at UID 501
allUsers := []string{"vma", "user", "test", "admin"}
for _, username := range allUsers {
if u, err := user.Lookup(username); err == nil {
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
if err != nil {
continue
}
if uid64 > 100 { // Regular user
return u, nil
}
}
}
// If no common users found, return an error
return nil, fmt.Errorf("no suitable non-root user found on this system")
}
func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
pd := NewPrivilegeDropper()
currentUID := uint32(os.Geteuid())
if currentUID == 0 {
// When running as root, test that root can create commands for any user
config := ExecutorConfig{
UID: 1000, // Target non-root user
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/tmp",
Shell: "/bin/sh",
Command: "echo test",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
assert.NoError(t, err, "Root should be able to create commands for any user")
assert.NotNil(t, cmd)
} else {
// When running as non-root, test that we can't drop to a different user
config := ExecutorConfig{
UID: 0, // Try to target root
GID: 0,
Groups: []uint32{0},
WorkingDir: "/tmp",
Shell: "/bin/sh",
Command: "echo test",
}
_, err := pd.CreateExecutorCommand(context.Background(), config)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot drop privileges")
}
}

View File

@@ -0,0 +1,570 @@
//go:build windows
package server
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"strings"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
ExitCodeSuccess = 0
ExitCodeLogonFail = 10
ExitCodeCreateProcessFail = 11
ExitCodeWorkingDirFail = 12
ExitCodeShellExecFail = 13
ExitCodeValidationFail = 14
)
type WindowsExecutorConfig struct {
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Interactive bool
Pty bool
PtyWidth int
PtyHeight int
}
type PrivilegeDropper struct{}
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
var (
advapi32 = windows.NewLazyDLL("advapi32.dll")
procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId")
)
const (
logon32LogonNetwork = 3 // Network logon - no password required for authenticated users
// Common error messages
commandFlag = "-Command"
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
convertUsernameError = "convert username to UTF16: %w"
convertDomainError = "convert domain to UTF16: %w"
)
// CreateWindowsExecutorCommand creates a Windows command with privilege dropping.
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
if config.Username == "" {
return nil, 0, errors.New("username cannot be empty")
}
if config.Shell == "" {
return nil, 0, errors.New("shell cannot be empty")
}
shell := config.Shell
var shellArgs []string
if config.Command != "" {
shellArgs = []string{shell, commandFlag, config.Command}
} else {
shellArgs = []string{shell}
}
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
if err != nil {
return nil, 0, fmt.Errorf("create Windows process as user: %w", err)
}
return cmd, token, nil
}
const (
// StatusSuccess represents successful LSA operation
StatusSuccess = 0
// KerbS4ULogonType message type for domain users with Kerberos
KerbS4ULogonType = 12
// Msv10s4ulogontype message type for local users with MSV1_0
Msv10s4ulogontype = 12
// MicrosoftKerberosNameA is the authentication package name for Kerberos
MicrosoftKerberosNameA = "Kerberos"
// Msv10packagename is the authentication package name for MSV1_0
Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"
NameSamCompatible = 2
NameUserPrincipal = 8
NameCanonical = 7
maxUPNLen = 1024
)
// kerbS4ULogon structure for S4U authentication (domain users)
type kerbS4ULogon struct {
MessageType uint32
Flags uint32
ClientUpn unicodeString
ClientRealm unicodeString
}
// msv10s4ulogon structure for S4U authentication (local users)
type msv10s4ulogon struct {
MessageType uint32
Flags uint32
UserPrincipalName unicodeString
DomainName unicodeString
}
// unicodeString structure
type unicodeString struct {
Length uint16
MaximumLength uint16
Buffer *uint16
}
// lsaString structure
type lsaString struct {
Length uint16
MaximumLength uint16
Buffer *byte
}
// tokenSource structure
type tokenSource struct {
SourceName [8]byte
SourceIdentifier windows.LUID
}
// quotaLimits structure
type quotaLimits struct {
PagedPoolLimit uint32
NonPagedPoolLimit uint32
MinimumWorkingSetSize uint32
MaximumWorkingSetSize uint32
PagefileLimit uint32
TimeLimit int64
}
var (
secur32 = windows.NewLazyDLL("secur32.dll")
procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess")
procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage")
procLsaLogonUser = secur32.NewProc("LsaLogonUser")
procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer")
procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess")
procTranslateNameW = secur32.NewProc("TranslateNameW")
)
// newLsaString creates an LsaString from a Go string
func newLsaString(s string) lsaString {
b := append([]byte(s), 0)
return lsaString{
Length: uint16(len(s)),
MaximumLength: uint16(len(b)),
Buffer: &b[0],
}
}
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)
pd := NewPrivilegeDropper()
isDomainUser := !pd.isLocalUser(domain)
lsaHandle, err := initializeLsaConnection()
if err != nil {
return 0, err
}
defer cleanupLsaConnection(lsaHandle)
authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser)
if err != nil {
return 0, err
}
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
if err != nil {
return 0, err
}
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
}
// buildUserCpn constructs the user principal name
func buildUserCpn(username, domain string) string {
if domain != "" && domain != "." {
return fmt.Sprintf(`%s\%s`, domain, username)
}
return username
}
// initializeLsaConnection establishes connection to LSA
func initializeLsaConnection() (windows.Handle, error) {
processName := newLsaString("NetBird")
var mode uint32
var lsaHandle windows.Handle
ret, _, _ := procLsaRegisterLogonProcess.Call(
uintptr(unsafe.Pointer(&processName)),
uintptr(unsafe.Pointer(&lsaHandle)),
uintptr(unsafe.Pointer(&mode)),
)
if ret != StatusSuccess {
return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret)
}
return lsaHandle, nil
}
// cleanupLsaConnection closes the LSA connection
func cleanupLsaConnection(lsaHandle windows.Handle) {
if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess {
log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret)
}
}
// lookupAuthenticationPackage finds the correct authentication package
func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) {
var authPackageName lsaString
if isDomainUser {
authPackageName = newLsaString(MicrosoftKerberosNameA)
} else {
authPackageName = newLsaString(Msv10packagename)
}
var authPackageId uint32
ret, _, _ := procLsaLookupAuthenticationPackage.Call(
uintptr(lsaHandle),
uintptr(unsafe.Pointer(&authPackageName)),
uintptr(unsafe.Pointer(&authPackageId)),
)
if ret != StatusSuccess {
return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret)
}
return authPackageId, nil
}
// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format)
func lookupPrincipalName(username, domain string) (string, error) {
samAccountName := fmt.Sprintf(`%s\%s`, domain, username)
samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName)
if err != nil {
return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err)
}
upnBuf := make([]uint16, maxUPNLen+1)
upnSize := uint32(len(upnBuf))
ret, _, _ := procTranslateNameW.Call(
uintptr(unsafe.Pointer(samAccountNameUtf16)),
uintptr(NameSamCompatible),
uintptr(NameUserPrincipal),
uintptr(unsafe.Pointer(&upnBuf[0])),
uintptr(unsafe.Pointer(&upnSize)),
)
if ret != 0 {
upn := windows.UTF16ToString(upnBuf[:upnSize])
log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn)
return upn, nil
}
upnSize = uint32(len(upnBuf))
ret, _, _ = procTranslateNameW.Call(
uintptr(unsafe.Pointer(samAccountNameUtf16)),
uintptr(NameSamCompatible),
uintptr(NameCanonical),
uintptr(unsafe.Pointer(&upnBuf[0])),
uintptr(unsafe.Pointer(&upnSize)),
)
if ret != 0 {
canonical := windows.UTF16ToString(upnBuf[:upnSize])
slashIdx := strings.IndexByte(canonical, '/')
if slashIdx > 0 {
fqdn := canonical[:slashIdx]
upn := fmt.Sprintf("%s@%s", username, fqdn)
log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical)
return upn, nil
}
}
log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName)
return samAccountName, nil
}
// prepareS4ULogonStructure creates the appropriate S4U logon structure
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
if isDomainUser {
return prepareDomainS4ULogon(username, domain)
}
return prepareLocalS4ULogon(username)
}
// prepareDomainS4ULogon creates S4U logon structure for domain users
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
upn, err := lookupPrincipalName(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
}
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
upnUtf16, err := windows.UTF16FromString(upn)
if err != nil {
return nil, 0, fmt.Errorf(convertUsernameError, err)
}
structSize := unsafe.Sizeof(kerbS4ULogon{})
upnByteSize := len(upnUtf16) * 2
logonInfoSize := structSize + uintptr(upnByteSize)
buffer := make([]byte, logonInfoSize)
logonInfo := unsafe.Pointer(&buffer[0])
s4uLogon := (*kerbS4ULogon)(logonInfo)
s4uLogon.MessageType = KerbS4ULogonType
s4uLogon.Flags = 0
upnOffset := structSize
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
s4uLogon.ClientUpn = unicodeString{
Length: uint16((len(upnUtf16) - 1) * 2),
MaximumLength: uint16(len(upnUtf16) * 2),
Buffer: upnBuffer,
}
s4uLogon.ClientRealm = unicodeString{}
return logonInfo, logonInfoSize, nil
}
// prepareLocalS4ULogon creates S4U logon structure for local users
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
usernameUtf16, err := windows.UTF16FromString(username)
if err != nil {
return nil, 0, fmt.Errorf(convertUsernameError, err)
}
domainUtf16, err := windows.UTF16FromString(".")
if err != nil {
return nil, 0, fmt.Errorf(convertDomainError, err)
}
structSize := unsafe.Sizeof(msv10s4ulogon{})
usernameByteSize := len(usernameUtf16) * 2
domainByteSize := len(domainUtf16) * 2
logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize)
buffer := make([]byte, logonInfoSize)
logonInfo := unsafe.Pointer(&buffer[0])
s4uLogon := (*msv10s4ulogon)(logonInfo)
s4uLogon.MessageType = Msv10s4ulogontype
s4uLogon.Flags = 0x0
usernameOffset := structSize
usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset))
copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16)
s4uLogon.UserPrincipalName = unicodeString{
Length: uint16((len(usernameUtf16) - 1) * 2),
MaximumLength: uint16(len(usernameUtf16) * 2),
Buffer: usernameBuffer,
}
domainOffset := usernameOffset + uintptr(usernameByteSize)
domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset))
copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16)
s4uLogon.DomainName = unicodeString{
Length: uint16((len(domainUtf16) - 1) * 2),
MaximumLength: uint16(len(domainUtf16) * 2),
Buffer: domainBuffer,
}
return logonInfo, logonInfoSize, nil
}
// performS4ULogon executes the S4U logon operation
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
var tokenSource tokenSource
copy(tokenSource.SourceName[:], "netbird")
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
log.Debugf("AllocateLocallyUniqueId failed")
}
originName := newLsaString("netbird")
var profile uintptr
var profileSize uint32
var logonId windows.LUID
var token windows.Handle
var quotas quotaLimits
var subStatus int32
ret, _, _ := procLsaLogonUser.Call(
uintptr(lsaHandle),
uintptr(unsafe.Pointer(&originName)),
logon32LogonNetwork,
uintptr(authPackageId),
uintptr(logonInfo),
logonInfoSize,
0,
uintptr(unsafe.Pointer(&tokenSource)),
uintptr(unsafe.Pointer(&profile)),
uintptr(unsafe.Pointer(&profileSize)),
uintptr(unsafe.Pointer(&logonId)),
uintptr(unsafe.Pointer(&token)),
uintptr(unsafe.Pointer(&quotas)),
uintptr(unsafe.Pointer(&subStatus)),
)
if profile != 0 {
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
}
}
if ret != StatusSuccess {
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
}
log.Debugf("created S4U %s token for user %s",
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
return token, nil
}
// createToken implements NetBird trust-based authentication using S4U
func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) {
fullUsername := buildUserCpn(username, domain)
if err := userExists(fullUsername, username, domain); err != nil {
return 0, err
}
isLocalUser := pd.isLocalUser(domain)
if isLocalUser {
return pd.authenticateLocalUser(username, fullUsername)
}
return pd.authenticateDomainUser(username, domain, fullUsername)
}
// userExists checks if the target useVerifier exists on the system
func userExists(fullUsername, username, domain string) error {
if _, err := lookupUser(fullUsername); err != nil {
log.Debugf("User %s not found: %v", fullUsername, err)
if domain != "" && domain != "." {
_, err = lookupUser(username)
}
if err != nil {
return fmt.Errorf("target user %s not found: %w", fullUsername, err)
}
}
return nil
}
// isLocalUser determines if this is a local user vs domain user
func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
hostname, err := os.Hostname()
if err != nil {
hostname = "localhost"
}
return domain == "" || domain == "." ||
strings.EqualFold(domain, hostname)
}
// authenticateLocalUser handles authentication for local users
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(username, ".")
if err != nil {
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
}
return token, nil
}
// authenticateDomainUser handles authentication for domain users
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(username, domain)
if err != nil {
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
}
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
return token, nil
}
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables).
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) {
token, err := pd.createToken(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("user authentication: %w", err)
}
defer func() {
if err := windows.CloseHandle(token); err != nil {
log.Debugf("close impersonation token: %v", err)
}
}()
cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
if err != nil {
return nil, 0, err
}
return cmd, primaryToken, nil
}
// createProcessWithToken creates process with the specified token and executable path.
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
cmd.Dir = workingDir
var primaryToken windows.Token
err := windows.DuplicateTokenEx(
sourceToken,
windows.TOKEN_ALL_ACCESS,
nil,
windows.SecurityIdentification,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err)
}
cmd.SysProcAttr = &syscall.SysProcAttr{
Token: syscall.Token(primaryToken),
}
return cmd, primaryToken, nil
}
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
return nil, fmt.Errorf("su command not available on Windows")
}

View File

@@ -0,0 +1,629 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
cryptossh "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT enforcement tests in short mode")
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
}
t.Logf("Detected server type: %s", serverType)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
})
t.Run("allows_when_disabled", func(t *testing.T) {
serverConfigNoJWT := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
serverNoJWT := New(serverConfigNoJWT)
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
assert.False(t, serverType.RequiresJWT())
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
require.NoError(t, err)
defer client.Close()
})
}
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64RawURLEncode(n),
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func base64RawURLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
// generateValidJWT creates a valid JWT token for testing
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
t.Helper()
addr := net.JoinHostPort(host, strconv.Itoa(port))
ctx := context.Background()
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
InsecureSkipVerify: true,
})
}
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
func TestJWTDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT detection test in short mode")
}
jwksServer, _, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
assert.True(t, serverType.RequiresJWT())
}
func TestJWTFailClose(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT fail-close tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
tokenClaims jwt.MapClaims
}{
{
name: "blocks_token_missing_iat",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
},
},
{
name: "blocks_token_missing_sub",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_iss",
tokenClaims: jwt.MapClaims{
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_aud",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_issuer",
tokenClaims: jwt.MapClaims{
"iss": "wrong-issuer",
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_audience",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": "wrong-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_expired_token",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
{
name: "blocks_token_exceeding_max_age",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(tokenString),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
assert.Error(t, err, "Authentication should fail (fail-close)")
})
}
}
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
func TestJWTAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT authentication tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
token string
wantAuthOK bool
setupServer func(*Server)
testOperation func(*testing.T, *cryptossh.Client, string) error
wantOpSuccess bool
}{
{
name: "allows_shell_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
return session.Shell()
},
wantOpSuccess: true,
},
{
name: "rejects_invalid_token",
token: "invalid",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_shell_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_command_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("ls")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "allows_sftp_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
return session.RequestSubsystem("sftp")
},
wantOpSuccess: true,
},
{
name: "blocks_sftp_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
err = session.RequestSubsystem("sftp")
if err == nil {
err = session.Wait()
}
return err
},
wantOpSuccess: false,
},
{
name: "allows_port_forward_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowRemotePortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: true,
},
{
name: "blocks_port_forward_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowLocalPortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// TODO: Skip port forwarding tests on Windows - user switching not supported
// These features are tested on Linux/Unix platforms
if runtime.GOOS == "windows" &&
(tc.name == "allows_port_forward_with_jwt" ||
tc.name == "blocks_port_forward_without_jwt") {
t.Skip("Skipping port forwarding test on Windows - covered by Linux tests")
}
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
if tc.setupServer != nil {
tc.setupServer(server)
}
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),
}
}
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: authMethods,
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed")
} else if err != nil {
t.Logf("Connection failed as expected: %v", err)
return
}
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
err = tc.testOperation(t, conn, serverAddr)
if tc.wantOpSuccess {
require.NoError(t, err, "Operation should succeed")
} else {
assert.Error(t, err, "Operation should fail")
}
})
}
}

View File

@@ -0,0 +1,386 @@
package server
import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
)
// SessionKey uniquely identifies an SSH session
type SessionKey string
// ConnectionKey uniquely identifies a port forwarding connection within a session
type ConnectionKey string
// ForwardKey uniquely identifies a port forwarding listener
type ForwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
type tcpipForwardMsg struct {
Host string
Port uint32
}
// SetAllowLocalPortForwarding configures local port forwarding
func (s *Server) SetAllowLocalPortForwarding(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowLocalPortForwarding = allow
}
// SetAllowRemotePortForwarding configures remote port forwarding
func (s *Server) SetAllowRemotePortForwarding(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowRemotePortForwarding = allow
}
// configurePortForwarding sets up port forwarding callbacks
func (s *Server) configurePortForwarding(server *ssh.Server) {
allowLocal := s.allowLocalPortForwarding
allowRemote := s.allowRemotePortForwarding
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
if !allowLocal {
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
return false
}
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
return true
}
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
if !allowRemote {
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
return false
}
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
return true
}
log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote)
}
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
// Returns nil if allowed, error if denied.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
if ctx == nil {
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
}
username := ctx.User()
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: false,
FeatureName: forwardType + " port forwarding",
})
if !result.Allowed {
return result.Error
}
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
forwardType, result.User.Username, port)
return nil
}
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
if !s.isRemotePortForwardingAllowed() {
logger.Warnf("tcpip-forward request denied: remote port forwarding disabled")
return false, nil
}
payload, err := s.parseTcpipForwardRequest(req)
if err != nil {
logger.Errorf("tcpip-forward unmarshal error: %v", err)
return false, nil
}
if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil {
logger.Warnf("tcpip-forward denied: %v", err)
return false, nil
}
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
sshConn, err := s.getSSHConnection(ctx)
if err != nil {
logger.Warnf("tcpip-forward request denied: %v", err)
return false, nil
}
return s.setupDirectForward(ctx, logger, sshConn, payload)
}
// cancelTcpipForwardHandler handles cancel-tcpip-forward requests.
func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
var payload tcpipForwardMsg
if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
return false, nil
}
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
if s.removeRemoteForwardListener(key) {
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
return true, nil
}
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
return false, nil
}
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
defer func() {
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
if err := ln.Close(); err != nil {
log.Debugf("remote forward listener close error: %v", err)
} else {
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
}
}()
acceptChan := make(chan acceptResult, 1)
go func() {
for {
conn, err := ln.Accept()
select {
case acceptChan <- acceptResult{conn: conn, err: err}:
if err != nil {
return
}
case <-ctx.Done():
return
}
}
}()
for {
select {
case result := <-acceptChan:
if result.err != nil {
log.Debugf("remote forward accept error: %v", result.err)
return
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
return
}
}
}
// getRequestLogger creates a logger with user and remote address context
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
remoteAddr := "unknown"
username := "unknown"
if ctx != nil {
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
username = ctx.User()
}
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
}
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
func (s *Server) isRemotePortForwardingAllowed() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg
err := cryptossh.Unmarshal(req.Payload, &payload)
return &payload, err
}
// getSSHConnection extracts SSH connection from context
func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
if ctx == nil {
return nil, fmt.Errorf("no context")
}
sshConnValue := ctx.Value(ssh.ContextKeyConn)
if sshConnValue == nil {
return nil, fmt.Errorf("no SSH connection in context")
}
sshConn, ok := sshConnValue.(*cryptossh.ServerConn)
if !ok || sshConn == nil {
return nil, fmt.Errorf("invalid SSH connection in context")
}
return sshConn, nil
}
// setupDirectForward sets up a direct port forward
func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))
ln, err := net.Listen("tcp", bindAddr)
if err != nil {
logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
return false, nil
}
actualPort := payload.Port
if payload.Port == 0 {
tcpAddr := ln.Addr().(*net.TCPAddr)
actualPort = uint32(tcpAddr.Port)
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
s.storeRemoteForwardListener(key, ln)
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
binary.BigEndian.PutUint32(response, actualPort)
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
return true, response
}
// acceptResult holds the result of a listener Accept() call
type acceptResult struct {
conn net.Conn
err error
}
// handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
sessionKey := s.findSessionKeyByContext(ctx)
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
logger := log.WithFields(log.Fields{
"session": sessionKey,
"conn": connID,
})
defer func() {
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}()
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if sshConn == nil {
logger.Debugf("remote forward: no SSH connection in context")
return
}
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
return
}
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
if err != nil {
logger.Debugf("open forward channel: %v", err)
return
}
s.proxyForwardConnection(ctx, logger, conn, channel)
}
// openForwardChannel creates an SSH forwarded-tcpip channel
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
payload := struct {
ConnectedAddress string
ConnectedPort uint32
OriginatorAddress string
OriginatorPort uint32
}{
ConnectedAddress: host,
ConnectedPort: port,
OriginatorAddress: remoteAddr.IP.String(),
OriginatorPort: uint32(remoteAddr.Port),
}
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload))
if err != nil {
return nil, fmt.Errorf("open SSH channel: %w", err)
}
go cryptossh.DiscardRequests(reqs)
return channel, nil
}
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(channel, conn); err != nil {
logger.Debugf("copy error (conn->channel): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn, channel); err != nil {
logger.Debugf("copy error (channel->conn): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
// First copy finished, wait for second copy or context cancellation
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
}
}
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}

712
client/ssh/server/server.go Normal file
View File

@@ -0,0 +1,712 @@
package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/version"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 22
// InternalSSHPort is the port SSH server listens on and is redirected to
const InternalSSHPort = 22022
const (
errWriteSession = "write session error: %v"
errExitSession = "exit session error: %v"
msgPrivilegedUserDisabled = "privileged user login is disabled"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
ErrUserNotFound = errors.New("user not found")
)
// PrivilegedUserError represents an error when privileged user login is disabled
type PrivilegedUserError struct {
Username string
}
func (e *PrivilegedUserError) Error() string {
return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username)
}
func (e *PrivilegedUserError) Is(target error) bool {
return target == ErrPrivilegedUserDisabled
}
// UserNotFoundError represents an error when a user cannot be found
type UserNotFoundError struct {
Username string
Cause error
}
func (e *UserNotFoundError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause)
}
return fmt.Sprintf("user %s not found", e.Username)
}
func (e *UserNotFoundError) Is(target error) bool {
return target == ErrUserNotFound
}
func (e *UserNotFoundError) Unwrap() error {
return e.Cause
}
// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors
func logSessionExitError(logger *log.Entry, err error) {
if err != nil && !errors.Is(err, io.EOF) {
logger.Warnf(errExitSession, err)
}
}
// safeLogCommand returns a safe representation of the command for logging
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return "<interactive shell>"
}
if len(cmd) == 1 {
return cmd[0]
}
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
type sshConnectionState struct {
hasActivePortForward bool
username string
remoteAddr string
}
type authKey string
func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
}
type Server struct {
sshServer *ssh.Server
mu sync.RWMutex
hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
sessionCancels map[ConnectionKey]context.CancelFunc
sessionJWTUsers map[SessionKey]string
pendingAuthJWT map[authKey]string
allowLocalPortForwarding bool
allowRemotePortForwarding bool
allowRootLogin bool
allowSFTP bool
jwtEnabled bool
netstackNet *netstack.Net
wgAddress wgaddr.Address
remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
suSupportsPty bool
}
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
}
// Config contains all SSH server configuration options
type Config struct {
// JWT authentication configuration. If nil, JWT authentication is disabled
JWT *JWTConfig
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
}
// SessionInfo contains information about an active SSH session
type SessionInfo struct {
Username string
RemoteAddress string
Command string
JWTUsername string
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session),
sessionJWTUsers: make(map[SessionKey]string),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
}
return s
}
// Start runs the SSH server
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer != nil {
return errors.New("SSH server is already running")
}
s.suSupportsPty = s.detectSuPtySupport(ctx)
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {
return fmt.Errorf("create listener: %w", err)
}
sshServer, err := s.createSSHServer(ln.Addr())
if err != nil {
s.closeListener(ln)
return fmt.Errorf("create SSH server: %w", err)
}
s.sshServer = sshServer
log.Infof("SSH server started on %s", addrDesc)
go func() {
if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Errorf("SSH server error: %v", err)
}
}()
return nil
}
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
if err != nil {
return nil, "", fmt.Errorf("listen on netstack: %w", err)
}
return ln, fmt.Sprintf("netstack %s", addr), nil
}
tcpAddr := net.TCPAddrFromAddrPort(addr)
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", tcpAddr.String())
if err != nil {
return nil, "", fmt.Errorf("listen: %w", err)
}
return ln, addr.String(), nil
}
func (s *Server) closeListener(ln net.Listener) {
if ln == nil {
return
}
if err := ln.Close(); err != nil {
log.Debugf("listener close error: %v", err)
}
}
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer == nil {
return nil
}
if err := s.sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
s.sshServer = nil
maps.Clear(s.sessions)
maps.Clear(s.sessionJWTUsers)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.sshConnections)
for _, cancelFunc := range s.sessionCancels {
cancelFunc()
}
maps.Clear(s.sessionCancels)
for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil {
log.Debugf("close remote forward listener: %v", err)
}
}
maps.Clear(s.remoteForwardListeners)
return nil
}
// GetStatus returns the current status of the SSH server and active sessions
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()
defer s.mu.RUnlock()
enabled = s.sshServer != nil
for sessionKey, session := range s.sessions {
cmd := "<interactive shell>"
if len(session.Command()) > 0 {
cmd = safeLogCommand(session.Command())
}
jwtUsername := s.sessionJWTUsers[sessionKey]
sessions = append(sessions, SessionInfo{
Username: session.User(),
RemoteAddress: session.RemoteAddr().String(),
Command: cmd,
JWTUsername: jwtUsername,
})
}
return enabled, sessions
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = net
}
// SetNetworkValidation configures network-based connection filtering
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.mu.Lock()
defer s.mu.Unlock()
s.wgAddress = addr
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
s.mu.RUnlock()
return nil
}
config := s.jwtConfig
s.mu.RUnlock()
if config == nil {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.KeysLocation,
true,
)
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
)
s.mu.Lock()
defer s.mu.Unlock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
s.jwtValidator = validator
s.jwtExtractor = extractor
log.Infof("JWT validator initialized successfully")
return nil
}
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
s.mu.RLock()
jwtValidator := s.jwtValidator
jwtConfig := s.jwtConfig
s.mu.RUnlock()
if jwtValidator == nil {
return nil, fmt.Errorf("JWT validator not initialized")
}
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
}
}
return nil, fmt.Errorf("validate token: %w", err)
}
if err := s.checkTokenAge(token, jwtConfig); err != nil {
return nil, err
}
return token, nil
}
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil {
return nil
}
maxTokenAge := jwtConfig.MaxTokenAge
if maxTokenAge <= 0 {
maxTokenAge = DefaultJWTMaxTokenAge
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
if jwtExtractor == nil {
userID := extractUserID(token)
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
}
userAuth, err := jwtExtractor.ToUserAuth(token)
if err != nil {
userID := extractUserID(token)
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
}
if !s.hasSSHAccess(&userAuth) {
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
}
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
return userAuth.UserId != ""
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
return claims, nil
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
if state, exists := s.sshConnections[sshConn]; exists {
state.hasActivePortForward = true
} else {
s.sshConnections[sshConn] = &sshConnectionState{
hasActivePortForward: true,
username: username,
remoteAddr: remoteAddr,
}
}
}
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
}
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
if ctx == nil {
return "unknown"
}
// Try to match by SSH connection
sshConn := ctx.Value(ssh.ContextKeyConn)
if sshConn == nil {
return "unknown"
}
s.mu.RLock()
defer s.mu.RUnlock()
// Look through sessions to find one with matching connection
for sessionKey, session := range s.sessions {
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
return sessionKey
}
}
// If no session found, this might be during early connection setup
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
}
return "unknown"
}
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
localIP := s.wgAddress.IP
s.mu.RUnlock()
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
return conn
}
remoteAddr := conn.RemoteAddr()
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok {
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
return nil
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
if !ok {
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
return nil
}
// Block connections from our own IP (prevent local apps from connecting to ourselves)
if remoteIP == localIP {
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
return nil
}
if !netbirdNetwork.Contains(remoteIP) {
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
return conn
}
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err)
}
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
if s.jwtEnabled {
serverVersion += " " + detection.JWTRequiredMarker
}
server := &ssh.Server{
Addr: addr.String(),
Handler: s.sessionHandler,
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": s.sftpSubsystemHandler,
},
HostSigners: []ssh.Signer{},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": s.directTCPIPHandler,
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": s.tcpipForwardHandler,
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
},
ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
}
if s.jwtEnabled {
server.PasswordHandler = s.passwordHandler
}
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
if err := server.SetOption(hostKeyPEM); err != nil {
return nil, fmt.Errorf("set host key: %w", err)
}
s.configurePortForwarding(server)
return server, nil
}
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
ln, exists := s.remoteForwardListeners[key]
if !exists {
return false
}
delete(s.remoteForwardListeners, key)
if err := ln.Close(); err != nil {
log.Debugf("remote forward listener close error: %v", err)
}
return true
}
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
var payload struct {
Host string
Port uint32
OriginatorAddr string
OriginatorPort uint32
}
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
log.Debugf("channel reject error: %v", err)
}
return
}
s.mu.RLock()
allowLocal := s.allowLocalPortForwarding
s.mu.RUnlock()
if !allowLocal {
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
// Check privilege requirements for the destination port
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}

View File

@@ -0,0 +1,394 @@
package server
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/ssh"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
)
func TestServer_RootLoginRestriction(t *testing.T) {
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowRoot bool
username string
expectError bool
description string
}{
{
name: "root login allowed",
allowRoot: true,
username: "root",
expectError: false,
description: "Root login should succeed when allowed",
},
{
name: "root login denied",
allowRoot: false,
username: "root",
expectError: true,
description: "Root login should fail when disabled",
},
{
name: "regular user login always allowed",
allowRoot: false,
username: "testuser",
expectError: false,
description: "Regular user login should work regardless of root setting",
},
}
// Add Windows Administrator tests if on Windows
if runtime.GOOS == "windows" {
tests = append(tests, []struct {
name string
allowRoot bool
username string
expectError bool
description string
}{
{
name: "Administrator login allowed",
allowRoot: true,
username: "Administrator",
expectError: false,
description: "Administrator login should succeed when allowed",
},
{
name: "Administrator login denied",
allowRoot: false,
username: "Administrator",
expectError: true,
description: "Administrator login should fail when disabled",
},
{
name: "administrator login denied (lowercase)",
allowRoot: false,
username: "administrator",
expectError: true,
description: "administrator login should fail when disabled (case insensitive)",
},
}...)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock privileged environment to test root access controls
// Set up mock users based on platform
mockUsers := map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
"testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"),
}
// Add Windows-specific users for Administrator tests
if runtime.GOOS == "windows" {
mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator")
mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator")
}
cleanup := setupTestDependencies(
createTestUser("root", "0", "0", "/root"), // Running as root
nil,
runtime.GOOS,
0, // euid 0 (root)
mockUsers,
nil,
)
defer cleanup()
// Create server with specific configuration
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(tt.allowRoot)
// Test the userNameLookup method directly
user, err := server.userNameLookup(tt.username)
if tt.expectError {
assert.Error(t, err, tt.description)
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
// Check for appropriate error message based on platform capabilities
errorMsg := err.Error()
// Either privileged user restriction OR user switching limitation
hasPrivilegedError := strings.Contains(errorMsg, "privileged user")
hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported")
assert.True(t, hasPrivilegedError || hasSwitchingError,
"Expected privileged user or user switching error, got: %s", errorMsg)
}
} else {
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
// For privileged users, we expect either success or a different error
// (like user not found), but not the "login disabled" error
if err != nil {
assert.NotContains(t, err.Error(), "privileged user login is disabled")
}
} else {
// For regular users, lookup should generally succeed or fall back gracefully
// Note: may return current user as fallback
assert.NotNil(t, user)
}
}
})
}
}
func TestServer_PortForwardingRestriction(t *testing.T) {
// Test that the port forwarding callbacks properly respect configuration flags
// This is a unit test of the callback logic, not a full integration test
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
description string
}{
{
name: "all forwarding allowed",
allowLocalForwarding: true,
allowRemoteForwarding: true,
description: "Both local and remote forwarding should be allowed",
},
{
name: "local forwarding disabled",
allowLocalForwarding: false,
allowRemoteForwarding: true,
description: "Local forwarding should be denied when disabled",
},
{
name: "remote forwarding disabled",
allowLocalForwarding: true,
allowRemoteForwarding: false,
description: "Remote forwarding should be denied when disabled",
},
{
name: "all forwarding disabled",
allowLocalForwarding: false,
allowRemoteForwarding: false,
description: "Both forwarding types should be denied when disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create server with specific configuration
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
// We need to access the internal configuration to simulate the callback tests
// Since the callbacks are created inside the Start method, we'll test the logic directly
// Test the configuration values are set correctly
server.mu.RLock()
allowLocal := server.allowLocalPortForwarding
allowRemote := server.allowRemotePortForwarding
server.mu.RUnlock()
assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly")
assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly")
// Simulate the callback logic
localResult := allowLocal // This would be the callback return value
remoteResult := allowRemote // This would be the callback return value
assert.Equal(t, tt.allowLocalForwarding, localResult,
"Local port forwarding callback should return correct value")
assert.Equal(t, tt.allowRemoteForwarding, remoteResult,
"Remote port forwarding callback should return correct value")
})
}
}
func TestServer_PortConflictHandling(t *testing.T) {
// Test that multiple sessions requesting the same local port are handled naturally by the OS
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Get a free port for testing
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
testPort := ln.Addr().(*net.TCPAddr).Port
err = ln.Close()
require.NoError(t, err)
// Connect first client
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1()
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client1.Close()
assert.NoError(t, err)
}()
// Connect second client
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client2.Close()
assert.NoError(t, err)
}()
// First client binds to the test port
localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort)
remoteAddr := "127.0.0.1:80"
// Start first client's port forwarding
done1 := make(chan error, 1)
go func() {
// This should succeed and hold the port
err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr)
done1 <- err
}()
// Give first client time to bind
time.Sleep(200 * time.Millisecond)
// Second client tries to bind to same port
localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort)
shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer shortCancel()
err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr)
// Second client should fail due to "address already in use"
assert.Error(t, err, "Second client should fail to bind to same port")
if err != nil {
// The error should indicate the address is already in use
errMsg := strings.ToLower(err.Error())
if runtime.GOOS == "windows" {
assert.Contains(t, errMsg, "only one usage of each socket address",
"Error should indicate port conflict")
} else {
assert.Contains(t, errMsg, "address already in use",
"Error should indicate port conflict")
}
}
// Cancel first client's context and wait for it to finish
cancel1()
select {
case err1 := <-done1:
// Should get context cancelled or deadline exceeded
assert.Error(t, err1, "First client should exit when context cancelled")
case <-time.After(2 * time.Second):
t.Error("First client did not exit within timeout")
}
}
func TestServer_IsPrivilegedUser(t *testing.T) {
tests := []struct {
username string
expected bool
description string
}{
{
username: "root",
expected: true,
description: "root should be considered privileged",
},
{
username: "regular",
expected: false,
description: "regular user should not be privileged",
},
{
username: "",
expected: false,
description: "empty username should not be privileged",
},
}
// Add Windows-specific tests
if runtime.GOOS == "windows" {
tests = append(tests, []struct {
username string
expected bool
description string
}{
{
username: "Administrator",
expected: true,
description: "Administrator should be considered privileged on Windows",
},
{
username: "administrator",
expected: true,
description: "administrator should be considered privileged on Windows (case insensitive)",
},
}...)
} else {
// On non-Windows systems, Administrator should not be privileged
tests = append(tests, []struct {
username string
expected bool
description string
}{
{
username: "Administrator",
expected: false,
description: "Administrator should not be privileged on non-Windows systems",
},
}...)
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
result := isPrivilegedUsername(tt.username)
assert.Equal(t, tt.expected, result, tt.description)
})
}
}

View File

@@ -0,0 +1,441 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"os/user"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestServer_StartStop(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: key,
JWT: nil,
}
server := New(serverConfig)
err = server.Stop()
assert.NoError(t, err)
}
func TestSSHServerIntegration(t *testing.T) {
// Generate host key for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server with random port
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server in background
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
// Get a free port
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key for verification
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Create SSH client config
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// Connect to SSH server
client, err := cryptossh.Dial("tcp", serverAddr, config)
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("close client: %v", err)
}
}()
// Test creating a session
session, err := client.NewSession()
require.NoError(t, err)
defer func() {
if err := session.Close(); err != nil {
t.Logf("close session: %v", err)
}
}()
// Note: Since we don't have a real shell environment in tests,
// we can't test actual command execution, but we can verify
// the connection and authentication work
t.Log("SSH connection and authentication successful")
}
func TestSSHServerMultipleConnections(t *testing.T) {
// Generate host key for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// Test multiple concurrent connections
const numConnections = 5
results := make(chan error, numConnections)
for i := 0; i < numConnections; i++ {
go func(id int) {
client, err := cryptossh.Dial("tcp", serverAddr, config)
if err != nil {
results <- fmt.Errorf("connection %d failed: %w", id, err)
return
}
defer func() {
_ = client.Close() // Ignore error in test goroutine
}()
session, err := client.NewSession()
if err != nil {
results <- fmt.Errorf("session %d failed: %w", id, err)
return
}
defer func() {
_ = session.Close() // Ignore error in test goroutine
}()
results <- nil
}(i)
}
// Wait for all connections to complete
for i := 0; i < numConnections; i++ {
select {
case err := <-results:
assert.NoError(t, err)
case <-time.After(10 * time.Second):
t.Fatalf("Connection %d timed out", i)
}
}
}
func TestSSHServerNoAuthMode(t *testing.T) {
// Generate host key for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Generate a client private key for SSH protocol (server doesn't check it)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Try to connect with client key
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(clientSigner),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// This should succeed in no-auth mode (server doesn't verify keys)
conn, err := cryptossh.Dial("tcp", serverAddr, config)
assert.NoError(t, err, "Connection should succeed in no-auth mode")
if conn != nil {
assert.NoError(t, conn.Close())
}
}
func TestSSHServerStartStopCycle(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
serverAddr := "127.0.0.1:0"
// Test multiple start/stop cycles
for i := 0; i < 3; i++ {
t.Logf("Start/stop cycle %d", i+1)
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case <-started:
case err := <-errChan:
t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
case <-time.After(5 * time.Second):
t.Fatalf("Cycle %d: Server start timeout", i+1)
}
err = server.Stop()
require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
}
}
func TestSSHServer_WindowsShellHandling(t *testing.T) {
if testing.Short() {
t.Skip("Skipping Windows shell test in short mode")
}
server := &Server{}
if runtime.GOOS == "windows" {
// Test Windows cmd.exe shell behavior
args := server.getShellCommandArgs("cmd.exe", "echo test")
assert.Equal(t, "cmd.exe", args[0])
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
// Test PowerShell behavior
args = server.getShellCommandArgs("powershell.exe", "echo test")
assert.Equal(t, "powershell.exe", args[0])
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
} else {
// Test Unix shell behavior
args := server.getShellCommandArgs("/bin/sh", "echo test")
assert.Equal(t, "/bin/sh", args[0])
assert.Equal(t, "-l", args[1])
assert.Equal(t, "-c", args[2])
assert.Equal(t, "echo test", args[3])
}
}
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig1 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server1 := New(serverConfig1)
serverConfig2 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server2 := New(serverConfig2)
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
server2.SetAllowLocalPortForwarding(true)
server2.SetAllowRemotePortForwarding(true)
assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
}

View File

@@ -0,0 +1,168 @@
package server
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
)
// sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) {
sessionKey := s.registerSession(session)
key := newAuthKey(session.User(), session.RemoteAddr())
s.mu.Lock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername != "" {
s.sessionJWTUsers[sessionKey] = jwtUsername
delete(s.pendingAuthJWT, key)
}
s.mu.Unlock()
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
} else {
logger.Infof("SSH session started")
}
sessionStart := time.Now()
defer s.unregisterSession(sessionKey, session)
defer func() {
duration := time.Since(sessionStart).Round(time.Millisecond)
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Warnf("close session after %v: %v", duration, err)
}
logger.Infof("SSH session closed after %v", duration)
}()
privilegeResult, err := s.userPrivilegeCheck(session.User())
if err != nil {
s.handlePrivError(logger, session, err)
return
}
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
switch {
case isPty && hasCommand:
// ssh -t <host> <cmd> - Pty command execution
s.handleCommand(logger, session, privilegeResult, winCh)
case isPty:
// ssh <host> - Pty interactive session (login)
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
case hasCommand:
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
s.rejectInvalidSession(logger, session)
}
}
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
}
func (s *Server) registerSession(session ssh.Session) SessionKey {
sessionID := session.Context().Value(ssh.ContextKeySessionID)
if sessionID == nil {
sessionID = fmt.Sprintf("%p", session)
}
// Create a short 4-byte identifier from the full session ID
hasher := sha256.New()
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
hash := hasher.Sum(nil)
shortID := hex.EncodeToString(hash[:4])
remoteAddr := session.RemoteAddr().String()
username := session.User()
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
s.mu.Lock()
s.sessions[sessionKey] = session
s.mu.Unlock()
return sessionKey
}
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
s.mu.Lock()
delete(s.sessions, sessionKey)
delete(s.sessionJWTUsers, sessionKey)
// Cancel all port forwarding connections for this session
var connectionsToCancel []ConnectionKey
for key := range s.sessionCancels {
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
connectionsToCancel = append(connectionsToCancel, key)
}
}
for _, key := range connectionsToCancel {
if cancelFunc, exists := s.sessionCancels[key]; exists {
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
cancelFunc()
delete(s.sessionCancels, key)
}
}
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
delete(s.sshConnections, sshConn)
}
}
s.mu.Unlock()
}
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
logger.Warnf("user privilege check failed: %v", err)
errorMsg := s.buildUserLookupErrorMessage(err)
if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if exitErr := session.Exit(1); exitErr != nil {
logSessionExitError(logger, exitErr)
}
}
// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
func (s *Server) buildUserLookupErrorMessage(err error) string {
var privilegedErr *PrivilegedUserError
switch {
case errors.As(err, &privilegedErr):
if privilegedErr.Username == "root" {
return "root login is disabled on this SSH server\n"
}
return "privileged user access is disabled on this SSH server\n"
case errors.Is(err, ErrPrivilegeRequired):
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
case errors.Is(err, ErrPrivilegedUserSwitch):
return "Cannot switch to privileged user - current user lacks required privileges\n"
default:
return "User authentication failed\n"
}
}

View File

@@ -0,0 +1,22 @@
//go:build js
package server
import (
"fmt"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// handlePty is not supported on JS/WASM
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}

81
client/ssh/server/sftp.go Normal file
View File

@@ -0,0 +1,81 @@
package server
import (
"fmt"
"io"
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
)
// SetAllowSFTP enables or disables SFTP support
func (s *Server) SetAllowSFTP(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowSFTP = allow
}
// sftpSubsystemHandler handles SFTP subsystem requests
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
s.mu.RLock()
allowSFTP := s.allowSFTP
s.mu.RUnlock()
if !allowSFTP {
log.Debugf("SFTP subsystem request denied: SFTP disabled")
if err := sess.Exit(1); err != nil {
log.Debugf("SFTP session exit failed: %v", err)
}
return
}
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: sess.User(),
FeatureSupportsUserSwitch: true,
FeatureName: FeatureSFTP,
})
if !result.Allowed {
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
if err := sess.Exit(1); err != nil {
log.Debugf("exit SFTP session: %v", err)
}
return
}
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
if !result.RequiresUserSwitching {
if err := s.executeSftpDirect(sess); err != nil {
log.Errorf("SFTP direct execution: %v", err)
}
return
}
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
log.Errorf("SFTP privilege drop execution: %v", err)
}
}
// executeSftpDirect executes SFTP directly without privilege dropping
func (s *Server) executeSftpDirect(sess ssh.Session) error {
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
sftpServer, err := sftp.NewServer(sess)
if err != nil {
return fmt.Errorf("SFTP server creation: %w", err)
}
defer func() {
if err := sftpServer.Close(); err != nil {
log.Debugf("failed to close sftp server: %v", err)
}
}()
if err := sftpServer.Serve(); err != nil && err != io.EOF {
return fmt.Errorf("serve: %w", err)
}
return nil
}

View File

@@ -0,0 +1,12 @@
//go:build js
package server
import (
"os/user"
)
// parseUserCredentials is not supported on JS/WASM
func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) {
return 0, 0, nil, errNotSupported
}

View File

@@ -0,0 +1,228 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"os/user"
"testing"
"time"
"github.com/pkg/sftp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
)
func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Skip SFTP test when running as root due to protocol issues in some environments
if os.Geteuid() == 0 {
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
}
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server with SFTP enabled
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true)
server.SetAllowRootLogin(true)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// (currentUser already obtained at function start)
// Create SSH client connection
clientConfig := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 5 * time.Second,
}
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
require.NoError(t, err, "SSH connection should succeed")
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
// Create SFTP client
sftpClient, err := sftp.NewClient(conn)
require.NoError(t, err, "SFTP client creation should succeed")
defer func() {
if err := sftpClient.Close(); err != nil {
t.Logf("SFTP client close error: %v", err)
}
}()
// Test basic SFTP operations
workingDir, err := sftpClient.Getwd()
assert.NoError(t, err, "Should be able to get working directory")
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
// Test directory listing
files, err := sftpClient.ReadDir(".")
assert.NoError(t, err, "Should be able to list current directory")
assert.NotNil(t, files, "File list should not be nil")
}
func TestSSHServer_SFTPDisabled(t *testing.T) {
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server with SFTP disabled
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// (currentUser already obtained at function start)
// Create SSH client connection
clientConfig := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 5 * time.Second,
}
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
require.NoError(t, err, "SSH connection should succeed")
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
// Try to create SFTP client - should fail when SFTP is disabled
_, err = sftp.NewClient(conn)
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
}

View File

@@ -0,0 +1,71 @@
//go:build !windows
package server
import (
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
uid, gid, groups, err := s.parseUserCredentials(targetUser)
if err != nil {
return fmt.Errorf("parse user credentials: %w", err)
}
sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir)
if err != nil {
return fmt.Errorf("create executor: %w", err)
}
sftpCmd.Stdin = sess
sftpCmd.Stdout = sess
sftpCmd.Stderr = sess.Stderr()
log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid)
if err := sftpCmd.Start(); err != nil {
return fmt.Errorf("starting SFTP executor: %w", err)
}
if err := sftpCmd.Wait(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
log.Tracef("SFTP process exited with code %d", exitError.ExitCode())
return nil
}
return fmt.Errorf("exec: %w", err)
}
return nil
}
// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping
func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) {
netbirdPath, err := os.Executable()
if err != nil {
return nil, err
}
args := []string{
"ssh", "sftp",
"--uid", strconv.FormatUint(uint64(uid), 10),
"--gid", strconv.FormatUint(uint64(gid), 10),
"--working-dir", workingDir,
}
for _, group := range groups {
args = append(args, "--groups", strconv.FormatUint(uint64(group), 10))
}
log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args)
return exec.CommandContext(sess.Context(), netbirdPath, args...), nil
}

View File

@@ -0,0 +1,91 @@
//go:build windows
package server
import (
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
// createSftpCommand creates a Windows SFTP command with user switching.
// The caller must close the returned token handle after starting the process.
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
username, domain := s.parseUsername(targetUser.Username)
netbirdPath, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("get netbird executable path: %w", err)
}
args := []string{
"ssh", "sftp",
"--working-dir", targetUser.HomeDir,
"--windows-username", username,
"--windows-domain", domain,
}
pd := NewPrivilegeDropper()
token, err := pd.createToken(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("create token: %w", err)
}
defer func() {
if err := windows.CloseHandle(token); err != nil {
log.Warnf("failed to close impersonation token: %v", err)
}
}()
cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
if err != nil {
return nil, 0, fmt.Errorf("create SFTP command: %w", err)
}
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
return cmd, primaryToken, nil
}
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error {
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
}
}()
sftpCmd.Stdin = sess
sftpCmd.Stdout = sess
sftpCmd.Stderr = sess.Stderr()
if err := sftpCmd.Start(); err != nil {
return fmt.Errorf("starting sftp executor: %w", err)
}
if err := sftpCmd.Wait(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
log.Tracef("sftp process exited with code %d", exitError.ExitCode())
return nil
}
return fmt.Errorf("exec sftp: %w", err)
}
return nil
}
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
sftpCmd, token, err := s.createSftpCommand(targetUser, sess)
if err != nil {
return fmt.Errorf("create sftp: %w", err)
}
return s.executeSftpCommand(sess, sftpCmd, token)
}

180
client/ssh/server/shell.go Normal file
View File

@@ -0,0 +1,180 @@
package server
import (
"bufio"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
const (
defaultUnixShell = "/bin/sh"
pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
powershellExe = "powershell.exe"
)
// getUserShell returns the appropriate shell for the given user ID
// Handles all platform-specific logic and fallbacks consistently
func getUserShell(userID string) string {
switch runtime.GOOS {
case "windows":
return getWindowsUserShell()
default:
return getUnixUserShell(userID)
}
}
// getWindowsUserShell returns the best shell for Windows users.
// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection
// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters.
// PowerShell provides safer argument handling and is available on all modern Windows systems.
// Order: pwsh.exe -> powershell.exe
func getWindowsUserShell() string {
if path, err := exec.LookPath(pwshExe); err == nil {
return path
}
if path, err := exec.LookPath(powershellExe); err == nil {
return path
}
return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
}
// getUnixUserShell returns the shell for Unix-like systems
func getUnixUserShell(userID string) string {
shell := getShellFromPasswd(userID)
if shell != "" {
return shell
}
if shell := os.Getenv("SHELL"); shell != "" {
return shell
}
return defaultUnixShell
}
// getShellFromPasswd reads the shell from /etc/passwd for the given user ID
func getShellFromPasswd(userID string) string {
file, err := os.Open("/etc/passwd")
if err != nil {
return ""
}
defer func() {
if err := file.Close(); err != nil {
log.Warnf("close /etc/passwd file: %v", err)
}
}()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Split(line, ":")
if len(fields) < 7 {
continue
}
// field 2 is UID
if fields[2] == userID {
shell := strings.TrimSpace(fields[6])
return shell
}
}
if err := scanner.Err(); err != nil {
log.Warnf("error reading /etc/passwd: %v", err)
}
return ""
}
// prepareUserEnv prepares environment variables for user execution
func prepareUserEnv(user *user.User, shell string) []string {
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
if runtime.GOOS == "windows" {
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
}
return []string{
fmt.Sprint("SHELL=" + shell),
fmt.Sprint("USER=" + user.Username),
fmt.Sprint("LOGNAME=" + user.Username),
fmt.Sprint("HOME=" + user.HomeDir),
"PATH=" + pathValue,
}
}
// acceptEnv checks if environment variable from SSH client should be accepted
// This is a whitelist of variables that SSH clients can send to the server
func acceptEnv(envVar string) bool {
varName := envVar
if idx := strings.Index(envVar, "="); idx != -1 {
varName = envVar[:idx]
}
exactMatches := []string{
"LANG",
"LANGUAGE",
"TERM",
"COLORTERM",
"EDITOR",
"VISUAL",
"PAGER",
"LESS",
"LESSCHARSET",
"TZ",
}
prefixMatches := []string{
"LC_",
}
for _, exact := range exactMatches {
if varName == exact {
return true
}
}
for _, prefix := range prefixMatches {
if strings.HasPrefix(varName, prefix) {
return true
}
}
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

45
client/ssh/server/test.go Normal file
View File

@@ -0,0 +1,45 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
)
func StartTestServer(t *testing.T, server *Server) string {
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort := netip.MustParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
return actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
return ""
}

View File

@@ -0,0 +1,411 @@
package server
import (
"errors"
"fmt"
"os"
"os/user"
"runtime"
"strings"
log "github.com/sirupsen/logrus"
)
var (
ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges")
ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges")
)
// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.)
func isPlatformUnix() bool {
return getCurrentOS() != "windows"
}
// Dependency injection variables for testing - allows mocking dynamic runtime checks
var (
getCurrentUser = user.Current
lookupUser = user.Lookup
getCurrentOS = func() string { return runtime.GOOS }
getIsProcessPrivileged = isCurrentProcessPrivileged
getEuid = os.Geteuid
)
const (
// FeatureSSHLogin represents SSH login operations for privilege checking
FeatureSSHLogin = "SSH login"
// FeatureSFTP represents SFTP operations for privilege checking
FeatureSFTP = "SFTP"
)
// PrivilegeCheckRequest represents a privilege check request
type PrivilegeCheckRequest struct {
// Username being requested (empty = current user)
RequestedUsername string
FeatureSupportsUserSwitch bool // Does this feature/operation support user switching?
FeatureName string
}
// PrivilegeCheckResult represents the result of a privilege check
type PrivilegeCheckResult struct {
// Allowed indicates whether the privilege check passed
Allowed bool
// User is the effective user to use for the operation (nil if not allowed)
User *user.User
// Error contains the reason for denial (nil if allowed)
Error error
// UsedFallback indicates we fell back to current user instead of requested user.
// This happens on Unix when running as an unprivileged user (e.g., in containers)
// where there's no point in user switching since we lack privileges anyway.
// When true, all privilege checks have already been performed and no additional
// privilege dropping or root checks are needed - the current user is the target.
UsedFallback bool
// RequiresUserSwitching indicates whether user switching will actually occur
// (false for fallback cases where no actual switching happens)
RequiresUserSwitching bool
}
// CheckPrivileges performs comprehensive privilege checking for all SSH features.
// This is the single source of truth for privilege decisions across the SSH server.
func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult {
context, err := s.buildPrivilegeCheckContext(req.FeatureName)
if err != nil {
return PrivilegeCheckResult{Allowed: false, Error: err}
}
// Handle empty username case - but still check root access controls
if req.RequestedUsername == "" {
if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot {
return PrivilegeCheckResult{
Allowed: false,
Error: &PrivilegedUserError{Username: context.currentUser.Username},
}
}
return PrivilegeCheckResult{
Allowed: true,
User: context.currentUser,
RequiresUserSwitching: false,
}
}
return s.checkUserRequest(context, req)
}
// buildPrivilegeCheckContext gathers all the context needed for privilege checking
func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) {
currentUser, err := getCurrentUser()
if err != nil {
return nil, fmt.Errorf("get current user for %s: %w", featureName, err)
}
s.mu.RLock()
allowRoot := s.allowRootLogin
s.mu.RUnlock()
return &privilegeCheckContext{
currentUser: currentUser,
currentUserPrivileged: getIsProcessPrivileged(),
allowRoot: allowRoot,
}, nil
}
// checkUserRequest handles normal privilege checking flow for specific usernames
func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult {
if !ctx.currentUserPrivileged && isPlatformUnix() {
log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)",
ctx.currentUser.Username, req.FeatureName, req.RequestedUsername)
return PrivilegeCheckResult{
Allowed: true,
User: ctx.currentUser,
UsedFallback: true,
RequiresUserSwitching: false,
}
}
resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername)
if err != nil {
// Calculate if user switching would be required even if lookup failed
needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username)
return PrivilegeCheckResult{
Allowed: false,
Error: err,
RequiresUserSwitching: needsUserSwitching,
}
}
needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser)
if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot {
return PrivilegeCheckResult{
Allowed: false,
Error: &PrivilegedUserError{Username: resolvedUser.Username},
RequiresUserSwitching: needsUserSwitching,
}
}
if needsUserSwitching && !req.FeatureSupportsUserSwitch {
return PrivilegeCheckResult{
Allowed: false,
Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName),
RequiresUserSwitching: needsUserSwitching,
}
}
return PrivilegeCheckResult{
Allowed: true,
User: resolvedUser,
RequiresUserSwitching: needsUserSwitching,
}
}
// resolveRequestedUser resolves a username to its canonical user identity
func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) {
if requestedUsername == "" {
return getCurrentUser()
}
if err := validateUsername(requestedUsername); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err)
}
u, err := lookupUser(requestedUsername)
if err != nil {
return nil, &UserNotFoundError{Username: requestedUsername, Cause: err}
}
return u, nil
}
// isSameResolvedUser compares two resolved user identities
func isSameResolvedUser(user1, user2 *user.User) bool {
if user1 == nil || user2 == nil {
return user1 == user2
}
return user1.Uid == user2.Uid
}
// privilegeCheckContext holds all context needed for privilege checking
type privilegeCheckContext struct {
currentUser *user.User
currentUserPrivileged bool
allowRoot bool
}
// isSameUser checks if two usernames refer to the same user
// SECURITY: This function must be conservative - it should only return true
// when we're certain both usernames refer to the exact same user identity
func isSameUser(requestedUsername, currentUsername string) bool {
// Empty requested username means current user
if requestedUsername == "" {
return true
}
// Exact match (most common case)
if getCurrentOS() == "windows" {
if strings.EqualFold(requestedUsername, currentUsername) {
return true
}
} else {
if requestedUsername == currentUsername {
return true
}
}
// Windows domain resolution: only allow domain stripping when comparing
// a bare username against the current user's domain-qualified name
if getCurrentOS() == "windows" {
return isWindowsSameUser(requestedUsername, currentUsername)
}
return false
}
// isWindowsSameUser handles Windows-specific user comparison with domain logic
func isWindowsSameUser(requestedUsername, currentUsername string) bool {
// Extract domain and username parts
extractParts := func(name string) (domain, user string) {
// Handle DOMAIN\username format
if idx := strings.LastIndex(name, `\`); idx != -1 {
return name[:idx], name[idx+1:]
}
// Handle user@domain.com format
if idx := strings.Index(name, "@"); idx != -1 {
return name[idx+1:], name[:idx]
}
// No domain specified - local machine
return "", name
}
reqDomain, reqUser := extractParts(requestedUsername)
curDomain, curUser := extractParts(currentUsername)
// Case-insensitive username comparison
if !strings.EqualFold(reqUser, curUser) {
return false
}
// If requested username has no domain, it refers to local machine user
// Allow this to match the current user regardless of current user's domain
if reqDomain == "" {
return true
}
// If both have domains, they must match exactly (case-insensitive)
return strings.EqualFold(reqDomain, curDomain)
}
// SetAllowRootLogin configures root login access
func (s *Server) SetAllowRootLogin(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowRootLogin = allow
}
// userNameLookup performs user lookup with root login permission check
func (s *Server) userNameLookup(username string) (*user.User, error) {
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: true,
FeatureName: FeatureSSHLogin,
})
if !result.Allowed {
return nil, result.Error
}
return result.User, nil
}
// userPrivilegeCheck performs user lookup with full privilege check result
func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) {
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: true,
FeatureName: FeatureSSHLogin,
})
if !result.Allowed {
return result, result.Error
}
return result, nil
}
// isPrivilegedUsername checks if the given username represents a privileged user across platforms.
// On Unix: root
// On Windows: Administrator, SYSTEM (case-insensitive)
// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com"
func isPrivilegedUsername(username string) bool {
if getCurrentOS() != "windows" {
return username == "root"
}
bareUsername := username
// Handle Windows domain format: DOMAIN\username
if idx := strings.LastIndex(username, `\`); idx != -1 {
bareUsername = username[idx+1:]
}
// Handle email-style format: username@domain.com
if idx := strings.Index(bareUsername, "@"); idx != -1 {
bareUsername = bareUsername[:idx]
}
return isWindowsPrivilegedUser(bareUsername)
}
// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account
func isWindowsPrivilegedUser(bareUsername string) bool {
// common privileged usernames (case insensitive)
privilegedNames := []string{
"administrator",
"admin",
"root",
"system",
"localsystem",
"networkservice",
"localservice",
}
usernameLower := strings.ToLower(bareUsername)
for _, privilegedName := range privilegedNames {
if usernameLower == privilegedName {
return true
}
}
// computer accounts (ending with $) are not privileged by themselves
// They only gain privileges through group membership or specific SIDs
if targetUser, err := lookupUser(bareUsername); err == nil {
return isWindowsPrivilegedSID(targetUser.Uid)
}
return false
}
// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account
func isWindowsPrivilegedSID(sid string) bool {
privilegedSIDs := []string{
"S-1-5-18", // Local System (SYSTEM)
"S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE)
"S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE)
"S-1-5-32-544", // Administrators group (BUILTIN\Administrators)
"S-1-5-500", // Built-in Administrator account (local machine RID 500)
}
for _, privilegedSID := range privilegedSIDs {
if sid == privilegedSID {
return true
}
}
// Check for domain administrator accounts (RID 500 in any domain)
// Format: S-1-5-21-domain-domain-domain-500
// This is reliable as RID 500 is reserved for the domain Administrator account
if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") {
return true
}
// Check for other well-known privileged RIDs in domain contexts
// RID 512 = Domain Admins group, RID 516 = Domain Controllers group
if strings.HasPrefix(sid, "S-1-5-21-") {
if strings.HasSuffix(sid, "-512") || // Domain Admins group
strings.HasSuffix(sid, "-516") || // Domain Controllers group
strings.HasSuffix(sid, "-519") { // Enterprise Admins group
return true
}
}
return false
}
// isCurrentProcessPrivileged checks if the current process is running with elevated privileges.
// On Unix systems, this means running as root (UID 0).
// On Windows, this means running as Administrator or SYSTEM.
func isCurrentProcessPrivileged() bool {
if getCurrentOS() == "windows" {
return isWindowsElevated()
}
return getEuid() == 0
}
// isWindowsElevated checks if the current process is running with elevated privileges on Windows
func isWindowsElevated() bool {
currentUser, err := getCurrentUser()
if err != nil {
log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err)
return false
}
if isWindowsPrivilegedSID(currentUser.Uid) {
log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid)
return true
}
if isPrivilegedUsername(currentUser.Username) {
log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username)
return true
}
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
return false
}

View File

@@ -0,0 +1,8 @@
//go:build js
package server
// validateUsername is not supported on JS/WASM
func validateUsername(_ string) error {
return errNotSupported
}

View File

@@ -0,0 +1,908 @@
package server
import (
"errors"
"os/user"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
)
// Test helper functions
func createTestUser(username, uid, gid, homeDir string) *user.User {
return &user.User{
Uid: uid,
Gid: gid,
Username: username,
Name: username,
HomeDir: homeDir,
}
}
// Test dependency injection setup - injects platform dependencies to test real logic
func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() {
// Store originals
originalGetCurrentUser := getCurrentUser
originalLookupUser := lookupUser
originalGetCurrentOS := getCurrentOS
originalGetEuid := getEuid
// Reset caches to ensure clean test state
// Set test values - inject platform dependencies
getCurrentUser = func() (*user.User, error) {
return currentUser, currentUserErr
}
lookupUser = func(username string) (*user.User, error) {
if err, exists := lookupErrors[username]; exists {
return nil, err
}
if userObj, exists := lookupUsers[username]; exists {
return userObj, nil
}
return nil, errors.New("user: unknown user " + username)
}
getCurrentOS = func() string {
return os
}
getEuid = func() int {
return euid
}
// Mock privilege detection based on the test user
getIsProcessPrivileged = func() bool {
if currentUser == nil {
return false
}
// Check both username and SID for Windows systems
if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) {
return true
}
return isPrivilegedUsername(currentUser.Username)
}
// Return cleanup function
return func() {
getCurrentUser = originalGetCurrentUser
lookupUser = originalLookupUser
getCurrentOS = originalGetCurrentOS
getEuid = originalGetEuid
getIsProcessPrivileged = isCurrentProcessPrivileged
// Reset caches after test
}
}
func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) {
tests := []struct {
name string
os string
euid int
currentUser *user.User
requestedUsername string
featureSupportsUserSwitch bool
allowRoot bool
lookupUsers map[string]*user.User
expectedAllowed bool
expectedRequiresSwitch bool
}{
{
name: "linux_root_can_switch_to_alice",
os: "linux",
euid: 0, // Root process
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "alice",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
},
expectedAllowed: true,
expectedRequiresSwitch: true,
},
{
name: "linux_non_root_fallback_to_current_user",
os: "linux",
euid: 1000, // Non-root process
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "bob",
featureSupportsUserSwitch: true,
allowRoot: true,
expectedAllowed: true, // Should fallback to current user (alice)
expectedRequiresSwitch: false, // Fallback means no actual switching
},
{
name: "windows_admin_can_switch_to_alice",
os: "windows",
euid: 1000, // Irrelevant on Windows
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "alice",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
},
expectedAllowed: true,
expectedRequiresSwitch: true,
},
{
name: "windows_non_admin_no_fallback_hard_failure",
os: "windows",
euid: 1000, // Irrelevant on Windows
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
requestedUsername: "bob",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"),
},
expectedAllowed: true, // Let OS decide - deferred security check
expectedRequiresSwitch: true, // Different user was requested
},
// Comprehensive test matrix: non-root linux with different allowRoot settings
{
name: "linux_non_root_request_root_allowRoot_false",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: true, // Fallback allows access regardless of root setting
expectedRequiresSwitch: false, // Fallback case, no switching
},
{
name: "linux_non_root_request_root_allowRoot_true",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
expectedAllowed: true, // Should fallback to alice (non-privileged process)
expectedRequiresSwitch: false, // Fallback means no actual switching
},
// Windows admin test matrix
{
name: "windows_admin_request_root_allowRoot_false",
os: "windows",
euid: 1000,
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: false, // Root not allowed
expectedRequiresSwitch: true,
},
{
name: "windows_admin_request_root_allowRoot_true",
os: "windows",
euid: 1000,
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // Windows user switching should work like Unix
expectedRequiresSwitch: true,
},
// Windows non-admin test matrix
{
name: "windows_non_admin_request_root_allowRoot_false",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence)
expectedRequiresSwitch: true,
},
{
name: "windows_system_account_allowRoot_false",
os: "windows",
euid: 1000,
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: false, // Root not allowed
expectedRequiresSwitch: true,
},
{
name: "windows_system_account_allowRoot_true",
os: "windows",
euid: 1000,
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // SYSTEM can switch to root
expectedRequiresSwitch: true,
},
{
name: "windows_non_admin_request_root_allowRoot_true",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // Let OS decide - deferred security check
expectedRequiresSwitch: true,
},
// Feature doesn't support user switching scenarios
{
name: "linux_root_feature_no_user_switching_same_user",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "root", // Same user
featureSupportsUserSwitch: false,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // Same user should work regardless of feature support
expectedRequiresSwitch: false,
},
{
name: "linux_root_feature_no_user_switching_different_user",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "alice",
featureSupportsUserSwitch: false, // Feature doesn't support switching
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
},
expectedAllowed: false, // Should deny because feature doesn't support switching
expectedRequiresSwitch: true,
},
// Empty username (current user) scenarios
{
name: "linux_non_root_current_user_empty_username",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "", // Empty = current user
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: true, // Current user should always work
expectedRequiresSwitch: false,
},
{
name: "linux_root_current_user_empty_username_root_not_allowed",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "", // Empty = current user (root)
featureSupportsUserSwitch: true,
allowRoot: false, // Root not allowed
expectedAllowed: false, // Should deny root even when it's current user
expectedRequiresSwitch: false,
},
// User not found scenarios
{
name: "linux_root_user_not_found",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "nonexistent",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{}, // No users defined = user not found
expectedAllowed: false, // Should fail due to user not found
expectedRequiresSwitch: true,
},
// Windows feature doesn't support user switching
{
name: "windows_admin_feature_no_user_switching_different_user",
os: "windows",
euid: 1000,
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "alice",
featureSupportsUserSwitch: false, // Feature doesn't support switching
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
},
expectedAllowed: false, // Should deny because feature doesn't support switching
expectedRequiresSwitch: true,
},
// Windows regular user scenarios (non-admin)
{
name: "windows_regular_user_same_user",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "alice", // Same user
featureSupportsUserSwitch: true,
allowRoot: false,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
},
expectedAllowed: true, // Regular user accessing themselves should work
expectedRequiresSwitch: false, // No switching for same user
},
{
name: "windows_regular_user_empty_username",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "", // Empty = current user
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: true, // Current user should always work
expectedRequiresSwitch: false, // No switching for current user
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Inject platform dependencies to test real logic
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil)
defer cleanup()
server := &Server{allowRootLogin: tt.allowRoot}
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: tt.requestedUsername,
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
FeatureName: "SSH login",
})
assert.Equal(t, tt.expectedAllowed, result.Allowed)
assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching)
})
}
}
func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Fallback mechanism is Unix-specific")
}
// Create test scenario where fallback should occur
server := &Server{allowRootLogin: true}
// Mock dependencies to simulate non-privileged user
originalGetCurrentUser := getCurrentUser
originalGetIsProcessPrivileged := getIsProcessPrivileged
defer func() {
getCurrentUser = originalGetCurrentUser
getIsProcessPrivileged = originalGetIsProcessPrivileged
}()
// Set up mocks for fallback scenario
getCurrentUser = func() (*user.User, error) {
return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil
}
getIsProcessPrivileged = func() bool { return false } // Non-privileged
// Request different user - should fallback
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: "alice",
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
// Verify fallback occurred
assert.True(t, result.Allowed, "Should allow with fallback")
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
assert.Equal(t, "netbird", result.User.Username, "Should return current user")
assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used")
// Key assertion: When UsedFallback is true, no privilege dropping should be needed
// because all privilege checks have already been performed and we're using current user
t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed",
result.User.Username)
}
func TestPrivilegedUsernameDetection(t *testing.T) {
tests := []struct {
name string
username string
platform string
privileged bool
}{
// Unix/Linux tests
{"unix_root", "root", "linux", true},
{"unix_regular_user", "alice", "linux", false},
{"unix_root_capital", "Root", "linux", false}, // Case-sensitive
// Windows tests
{"windows_administrator", "Administrator", "windows", true},
{"windows_system", "SYSTEM", "windows", true},
{"windows_admin", "admin", "windows", true},
{"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive
{"windows_domain_admin", "DOMAIN\\Administrator", "windows", true},
{"windows_email_admin", "admin@domain.com", "windows", true},
{"windows_regular_user", "alice", "windows", false},
{"windows_domain_user", "DOMAIN\\alice", "windows", false},
{"windows_localsystem", "localsystem", "windows", true},
{"windows_networkservice", "networkservice", "windows", true},
{"windows_localservice", "localservice", "windows", true},
// Computer accounts (these depend on current user context in real implementation)
{"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged
{"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account
// Cross-platform
{"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock the platform for this test
cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil)
defer cleanup()
result := isPrivilegedUsername(tt.username)
assert.Equal(t, tt.privileged, result)
})
}
}
func TestWindowsPrivilegedSIDDetection(t *testing.T) {
tests := []struct {
name string
sid string
privileged bool
description string
}{
// Well-known system accounts
{"system_account", "S-1-5-18", true, "Local System (SYSTEM)"},
{"local_service", "S-1-5-19", true, "Local Service"},
{"network_service", "S-1-5-20", true, "Network Service"},
{"administrators_group", "S-1-5-32-544", true, "Administrators group"},
{"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"},
// Domain accounts
{"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"},
{"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"},
{"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"},
{"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"},
// Regular users
{"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"},
{"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"},
{"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"},
// Groups that are not privileged
{"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"},
{"power_users", "S-1-5-32-547", false, "Power Users group"},
// Invalid SIDs
{"malformed_sid", "S-1-5-invalid", false, "Malformed SID"},
{"empty_sid", "", false, "Empty SID"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isWindowsPrivilegedSID(tt.sid)
assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid)
})
}
}
func TestIsSameUser(t *testing.T) {
tests := []struct {
name string
user1 string
user2 string
os string
expected bool
}{
// Basic cases
{"same_username", "alice", "alice", "linux", true},
{"different_username", "alice", "bob", "linux", false},
// Linux (no domain processing)
{"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false},
{"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false},
{"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true},
// Windows (with domain processing) - Note: parameter order is (requested, current, os, expected)
{"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user
{"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user
{"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users
{"windows_case_insensitive", "Alice", "alice", "windows", true},
{"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up OS mock
cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil)
defer cleanup()
result := isSameUser(tt.user1, tt.user2)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUsernameValidation_Unix(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix-specific username validation tests")
}
tests := []struct {
name string
username string
wantErr bool
errMsg string
}{
// Valid usernames (Unix/POSIX)
{"valid_alphanumeric", "user123", false, ""},
{"valid_with_dots", "user.name", false, ""},
{"valid_with_hyphens", "user-name", false, ""},
{"valid_with_underscores", "user_name", false, ""},
{"valid_uppercase", "UserName", false, ""},
{"valid_starting_with_digit", "123user", false, ""},
{"valid_starting_with_dot", ".hidden", false, ""},
// Invalid usernames (Unix/POSIX)
{"empty_username", "", true, "username cannot be empty"},
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
{"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction
{"username_with_spaces", "user name", true, "invalid characters"},
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
{"username_with_pipe", "user|rm", true, "invalid characters"},
{"username_with_ampersand", "user&rm", true, "invalid characters"},
{"username_with_quotes", "user\"name", true, "invalid characters"},
{"username_with_newline", "user\nname", true, "invalid characters"},
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
{"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames
{"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsername(tt.username)
if tt.wantErr {
assert.Error(t, err, "Should reject invalid username")
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
}
} else {
assert.NoError(t, err, "Should accept valid username")
}
})
}
}
func TestUsernameValidation_Windows(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("Windows-specific username validation tests")
}
tests := []struct {
name string
username string
wantErr bool
errMsg string
}{
// Valid usernames (Windows)
{"valid_alphanumeric", "user123", false, ""},
{"valid_with_dots", "user.name", false, ""},
{"valid_with_hyphens", "user-name", false, ""},
{"valid_with_underscores", "user_name", false, ""},
{"valid_uppercase", "UserName", false, ""},
{"valid_starting_with_digit", "123user", false, ""},
{"valid_starting_with_dot", ".hidden", false, ""},
{"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this
{"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format
{"valid_email_username", "user@domain.com", false, ""}, // Windows email format
{"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format
// Invalid usernames (Windows)
{"empty_username", "", true, "username cannot be empty"},
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
{"username_with_spaces", "user name", true, "invalid characters"},
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
{"username_with_pipe", "user|rm", true, "invalid characters"},
{"username_with_ampersand", "user&rm", true, "invalid characters"},
{"username_with_quotes", "user\"name", true, "invalid characters"},
{"username_with_newline", "user\nname", true, "invalid characters"},
{"username_with_brackets", "user[name]", true, "invalid characters"},
{"username_with_colon", "user:name", true, "invalid characters"},
{"username_with_semicolon", "user;name", true, "invalid characters"},
{"username_with_equals", "user=name", true, "invalid characters"},
{"username_with_comma", "user,name", true, "invalid characters"},
{"username_with_plus", "user+name", true, "invalid characters"},
{"username_with_asterisk", "user*name", true, "invalid characters"},
{"username_with_question", "user?name", true, "invalid characters"},
{"username_with_angles", "user<name>", true, "invalid characters"},
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
{"username_ending_with_period", "user.", true, "cannot end with a period"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsername(tt.username)
if tt.wantErr {
assert.Error(t, err, "Should reject invalid username")
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
}
} else {
assert.NoError(t, err, "Should accept valid username")
}
})
}
}
// Test real-world integration scenarios with actual platform capabilities
func TestCheckPrivileges_RealWorldScenarios(t *testing.T) {
tests := []struct {
name string
feature string
featureSupportsUserSwitch bool
requestedUsername string
allowRoot bool
expectedBehaviorPattern string
}{
{"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"},
{"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"},
{"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"},
{"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"},
{"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock privileged environment to ensure consistent test behavior across environments
cleanup := setupTestDependencies(
createTestUser("root", "0", "0", "/root"), // Running as root
nil,
runtime.GOOS,
0, // euid 0 (root)
map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
"differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"),
},
nil,
)
defer cleanup()
server := &Server{allowRootLogin: tt.allowRoot}
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: tt.requestedUsername,
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
FeatureName: tt.feature,
})
switch tt.expectedBehaviorPattern {
case "should_allow_current_user":
assert.True(t, result.Allowed, "Should allow current user access")
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
case "should_deny_root":
assert.False(t, result.Allowed, "Should deny root when not allowed")
assert.Contains(t, result.Error.Error(), "root", "Should mention root in error")
case "should_deny_switching":
assert.False(t, result.Allowed, "Should deny when feature doesn't support switching")
assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error")
}
})
}
}
// Test with actual platform capabilities - no mocking
func TestCheckPrivileges_ActualPlatform(t *testing.T) {
// This test uses the REAL platform capabilities
server := &Server{allowRootLogin: true}
// Test current user access - should always work
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: "", // Current user
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
assert.True(t, result.Allowed, "Current user should always be allowed")
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
assert.NotNil(t, result.User, "Should return current user")
// Test user switching capability based on actual platform
actualIsPrivileged := isCurrentProcessPrivileged() // REAL check
actualOS := runtime.GOOS // REAL check
t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v",
actualOS, actualIsPrivileged, actualIsPrivileged)
// Test requesting different user
result = server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: "nonexistentuser",
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
switch {
case actualOS == "windows":
// Windows supports user switching but should fail on nonexistent user
assert.False(t, result.Allowed, "Windows should deny nonexistent user")
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
assert.Contains(t, result.Error.Error(), "not found",
"Should indicate user not found")
case !actualIsPrivileged:
// Non-privileged Unix processes should fallback to current user
assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
assert.NotNil(t, result.User, "Should return current user")
default:
// Privileged Unix processes should attempt user lookup
assert.False(t, result.Allowed, "Should fail due to nonexistent user")
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
assert.Contains(t, result.Error.Error(), "nonexistentuser",
"Should indicate user not found")
}
}
// Test platform detection logic with dependency injection
func TestPlatformLogic_DependencyInjection(t *testing.T) {
tests := []struct {
name string
os string
euid int
currentUser *user.User
expectedIsProcessPrivileged bool
expectedSupportsUserSwitching bool
}{
{
name: "linux_root_process",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
expectedIsProcessPrivileged: true,
expectedSupportsUserSwitching: true,
},
{
name: "linux_non_root_process",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
expectedIsProcessPrivileged: false,
expectedSupportsUserSwitching: false,
},
{
name: "windows_admin_process",
os: "windows",
euid: 1000, // euid ignored on Windows
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
expectedIsProcessPrivileged: true,
expectedSupportsUserSwitching: true, // Windows supports user switching when privileged
},
{
name: "windows_regular_process",
os: "windows",
euid: 1000, // euid ignored on Windows
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
expectedIsProcessPrivileged: false,
expectedSupportsUserSwitching: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Inject platform dependencies and test REAL logic
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil)
defer cleanup()
// Test the actual functions with injected dependencies
actualIsPrivileged := isCurrentProcessPrivileged()
actualSupportsUserSwitching := actualIsPrivileged
assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged,
"isCurrentProcessPrivileged() result mismatch")
assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching,
"supportsUserSwitching() result mismatch")
t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username)
t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v",
actualIsPrivileged, actualSupportsUserSwitching)
})
}
}
func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) {
// Test Windows elevated user switching scenarios with simplified privilege logic
tests := []struct {
name string
currentUser *user.User
requestedUsername string
allowRoot bool
expectedAllowed bool
expectedErrorContains string
}{
{
name: "windows_admin_can_switch_to_alice",
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
requestedUsername: "alice",
allowRoot: true,
expectedAllowed: true,
},
{
name: "windows_non_admin_can_try_switch",
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"),
requestedUsername: "bob",
allowRoot: true,
expectedAllowed: true, // Privilege check allows it, OS will reject during execution
},
{
name: "windows_system_can_switch_to_alice",
currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"),
requestedUsername: "alice",
allowRoot: true,
expectedAllowed: true,
},
{
name: "windows_admin_root_not_allowed",
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
requestedUsername: "root",
allowRoot: false,
expectedAllowed: false,
expectedErrorContains: "privileged user login is disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup test dependencies with Windows OS and specified privileges
lookupUsers := map[string]*user.User{
tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername),
}
cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil)
defer cleanup()
server := &Server{allowRootLogin: tt.allowRoot}
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: tt.requestedUsername,
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
assert.Equal(t, tt.expectedAllowed, result.Allowed,
"Privilege check result should match expected for %s", tt.name)
if !tt.expectedAllowed && tt.expectedErrorContains != "" {
assert.NotNil(t, result.Error, "Should have error when not allowed")
assert.Contains(t, result.Error.Error(), tt.expectedErrorContains,
"Error should contain expected message")
}
if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername {
assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user")
}
})
}
}

View File

@@ -0,0 +1,8 @@
//go:build js
package server
// enableUserSwitching is not supported on JS/WASM
func enableUserSwitching() error {
return errNotSupported
}

View File

@@ -0,0 +1,233 @@
//go:build unix
package server
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"os/user"
"regexp"
"runtime"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// POSIX portable filename character set regex: [a-zA-Z0-9._-]
// First character cannot be hyphen (POSIX requirement)
var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`)
// validateUsername validates that a username conforms to POSIX standards with security considerations
func validateUsername(username string) error {
if username == "" {
return errors.New("username cannot be empty")
}
// POSIX allows up to 256 characters, but practical limit is 32 for compatibility
if len(username) > 32 {
return errors.New("username too long (max 32 characters)")
}
if !posixUsernameRegex.MatchString(username) {
return errors.New("username contains invalid characters (must match POSIX portable filename character set)")
}
if username == "." || username == ".." {
return fmt.Errorf("username cannot be '.' or '..'")
}
// Warn if username is fully numeric (can cause issues with UID/username ambiguity)
if isFullyNumeric(username) {
log.Warnf("fully numeric username '%s' may cause issues with some commands", username)
}
return nil
}
// isFullyNumeric checks if username contains only digits
func isFullyNumeric(username string) bool {
for _, char := range username {
if char < '0' || char > '9' {
return false
}
}
return true
}
// createPtyLoginCommand creates a Pty command using login for privileged processes
func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("get login command: %w", err)
}
execCmd := exec.CommandContext(session.Context(), loginPath, args...)
execCmd.Dir = localUser.HomeDir
execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
return execCmd, nil
}
// getLoginCmd returns the login command and args for privileged Pty user switching
func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) {
loginPath, err := exec.LookPath("login")
if err != nil {
return "", nil, fmt.Errorf("login command not available: %w", err)
}
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
if err != nil {
return "", nil, fmt.Errorf("parse remote address: %w", err)
}
switch runtime.GOOS {
case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default:
return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS)
}
}
// fileExists checks if a file exists (helper for login command logic)
func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
// parseUserCredentials extracts numeric UID, GID, and supplementary groups
func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) {
uid64, err := strconv.ParseUint(localUser.Uid, 10, 32)
if err != nil {
return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err)
}
uid := uint32(uid64)
gid64, err := strconv.ParseUint(localUser.Gid, 10, 32)
if err != nil {
return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err)
}
gid := uint32(gid64)
groups, err := s.getSupplementaryGroups(localUser.Username)
if err != nil {
log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err)
groups = []uint32{gid}
}
return uid, gid, groups, nil
}
// getSupplementaryGroups retrieves supplementary group IDs for a user
func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
u, err := user.Lookup(username)
if err != nil {
return nil, fmt.Errorf("lookup user %s: %w", username, err)
}
groupIDStrings, err := u.GroupIds()
if err != nil {
return nil, fmt.Errorf("get group IDs for user %s: %w", username, err)
}
groups := make([]uint32, len(groupIDStrings))
for i, gidStr := range groupIDStrings {
gid64, err := strconv.ParseUint(gidStr, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err)
}
groups[i] = uint32(gid64)
}
return groups, nil
}
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
// Returns the command and a cleanup function (no-op on Unix).
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
if err := validateUsername(localUser.Username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
}
uid, gid, groups, err := s.parseUserCredentials(localUser)
if err != nil {
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
}
privilegeDropper := NewPrivilegeDropper()
config := ExecutorConfig{
UID: uid,
GID: gid,
Groups: groups,
WorkingDir: localUser.HomeDir,
Shell: getUserShell(localUser.Uid),
Command: session.RawCommand(),
PTY: hasPty,
}
cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config)
return cmd, func() {}, err
}
// enableUserSwitching is a no-op on Unix systems
func enableUserSwitching() error {
return nil
}
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, errors.New("no user in privilege result")
}
if privilegeResult.UsedFallback {
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
}
return s.createPtyLoginCommand(localUser, ptyReq, session)
}
// createDirectPtyCommand creates a direct Pty command without privilege dropping
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
shell := getUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
cmd.Dir = localUser.HomeDir
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
return cmd
}
// preparePtyEnv prepares environment variables for Pty execution
func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string {
termType := ptyReq.Term
if termType == "" {
termType = "xterm-256color"
}
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
env = append(env, fmt.Sprintf("TERM=%s", termType))
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}

View File

@@ -0,0 +1,274 @@
//go:build windows
package server
import (
"errors"
"fmt"
"os/exec"
"os/user"
"strings"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
// validateUsername validates Windows usernames according to SAM Account Name rules
func validateUsername(username string) error {
if username == "" {
return fmt.Errorf("username cannot be empty")
}
usernameToValidate := extractUsernameFromDomain(username)
if err := validateUsernameLength(usernameToValidate); err != nil {
return err
}
if err := validateUsernameCharacters(usernameToValidate); err != nil {
return err
}
if err := validateUsernameFormat(usernameToValidate); err != nil {
return err
}
return nil
}
// extractUsernameFromDomain extracts the username part from domain\username or username@domain format
func extractUsernameFromDomain(username string) string {
if idx := strings.LastIndex(username, `\`); idx != -1 {
return username[idx+1:]
}
if idx := strings.Index(username, "@"); idx != -1 {
return username[:idx]
}
return username
}
// validateUsernameLength checks if username length is within Windows limits
func validateUsernameLength(username string) error {
if len(username) > 20 {
return fmt.Errorf("username too long (max 20 characters for Windows)")
}
return nil
}
// validateUsernameCharacters checks for invalid characters in Windows usernames
func validateUsernameCharacters(username string) error {
invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'}
for _, char := range username {
for _, invalid := range invalidChars {
if char == invalid {
return fmt.Errorf("username contains invalid characters")
}
}
if char < 32 || char == 127 {
return fmt.Errorf("username contains control characters")
}
}
return nil
}
// validateUsernameFormat checks for invalid username formats and patterns
func validateUsernameFormat(username string) error {
if username == "." || username == ".." {
return fmt.Errorf("username cannot be '.' or '..'")
}
if strings.HasSuffix(username, ".") {
return fmt.Errorf("username cannot end with a period")
}
return nil
}
// createExecutorCommand creates a command using Windows executor for privilege dropping.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
username, _ := s.parseUsername(localUser.Username)
if err := validateUsername(username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
}
return s.createUserSwitchCommand(localUser, session, hasPty)
}
// createUserSwitchCommand creates a command with Windows user switching.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
rawCmd := session.RawCommand()
var command string
if rawCmd != "" {
command = rawCmd
}
config := WindowsExecutorConfig{
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Command: command,
Interactive: interactive || (rawCmd == ""),
}
dropper := NewPrivilegeDropper()
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
if err != nil {
return nil, nil, err
}
cleanup := func() {
if token != 0 {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
}
}
}
return cmd, cleanup, nil
}
// parseUsername extracts username and domain from a Windows username
func (s *Server) parseUsername(fullUsername string) (username, domain string) {
// Handle DOMAIN\username format
if idx := strings.LastIndex(fullUsername, `\`); idx != -1 {
domain = fullUsername[:idx]
username = fullUsername[idx+1:]
return username, domain
}
// Handle username@domain format
if username, domain, ok := strings.Cut(fullUsername, "@"); ok {
return username, domain
}
// Local user (no domain)
return fullUsername, "."
}
// hasPrivilege checks if the current process has a specific privilege
func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) {
var luid windows.LUID
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
return false, fmt.Errorf("lookup privilege value: %w", err)
}
var returnLength uint32
err := windows.GetTokenInformation(
windows.Token(token),
windows.TokenPrivileges,
nil, // null buffer to get size
0,
&returnLength,
)
if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return false, fmt.Errorf("get token information size: %w", err)
}
buffer := make([]byte, returnLength)
err = windows.GetTokenInformation(
windows.Token(token),
windows.TokenPrivileges,
&buffer[0],
returnLength,
&returnLength,
)
if err != nil {
return false, fmt.Errorf("get token information: %w", err)
}
privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0]))
// Check if the privilege is present and enabled
for i := uint32(0); i < privileges.PrivilegeCount; i++ {
privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer(
uintptr(unsafe.Pointer(&privileges.Privileges[0])) +
uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}),
))
if privilege.Luid == luid {
return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil
}
}
return false, nil
}
// enablePrivilege enables a specific privilege for the current process token
// This is required because privileges like SeAssignPrimaryTokenPrivilege are present
// but disabled by default, even for the SYSTEM account
func enablePrivilege(token windows.Handle, privilegeName string) error {
var luid windows.LUID
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err)
}
privileges := windows.Tokenprivileges{
PrivilegeCount: 1,
Privileges: [1]windows.LUIDAndAttributes{
{
Luid: luid,
Attributes: windows.SE_PRIVILEGE_ENABLED,
},
},
}
err := windows.AdjustTokenPrivileges(
windows.Token(token),
false,
&privileges,
0,
nil,
nil,
)
if err != nil {
return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err)
}
hasPriv, err := hasPrivilege(token, privilegeName)
if err != nil {
return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err)
}
if !hasPriv {
return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName)
}
log.Debugf("Successfully enabled privilege %s for current process", privilegeName)
return nil
}
// enableUserSwitching enables required privileges for Windows user switching
func enableUserSwitching() error {
process := windows.CurrentProcess()
var token windows.Token
err := windows.OpenProcessToken(
process,
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY,
&token,
)
if err != nil {
return fmt.Errorf("open process token: %w", err)
}
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("Failed to close process token: %v", err)
}
}()
if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil {
return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err)
}
log.Infof("Windows user switching privileges enabled successfully")
return nil
}

View File

@@ -0,0 +1,487 @@
//go:build windows
package winpty
import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
ErrEmptyEnvironment = errors.New("empty environment")
)
const (
extendedStartupInfoPresent = 0x00080000
createUnicodeEnvironment = 0x00000400
procThreadAttributePseudoConsole = 0x00020016
PowerShellCommandFlag = "-Command"
errCloseInputRead = "close input read handle: %v"
errCloseConPtyCleanup = "close ConPty handle during cleanup"
)
// PtyConfig holds configuration for Pty execution.
type PtyConfig struct {
Shell string
Command string
Width int
Height int
WorkingDir string
}
// UserConfig holds user execution configuration.
type UserConfig struct {
Token windows.Handle
Environment []string
}
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
)
// ExecutePtyWithUserToken executes a command with ConPty using user token.
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
commandLine := buildCommandLine(args)
config := ExecutionConfig{
Pty: ptyConfig,
User: userConfig,
Session: session,
Context: ctx,
}
return executeConPtyWithConfig(commandLine, config)
}
// ExecutionConfig holds all execution configuration.
type ExecutionConfig struct {
Pty PtyConfig
User UserConfig
Session ssh.Session
Context context.Context
}
// executeConPtyWithConfig creates ConPty and executes process with configuration.
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
ctx := config.Context
session := config.Session
width := config.Pty.Width
height := config.Pty.Height
userToken := config.User.Token
userEnv := config.User.Environment
workingDir := config.Pty.WorkingDir
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
if err != nil {
return fmt.Errorf("create ConPty pipes: %w", err)
}
hPty, err := createConPty(width, height, inputRead, outputWrite)
if err != nil {
return fmt.Errorf("create ConPty: %w", err)
}
primaryToken, err := duplicateToPrimaryToken(userToken)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("duplicate to primary token: %w", err)
}
defer func() {
if err := windows.CloseHandle(primaryToken); err != nil {
log.Debugf("close primary token: %v", err)
}
}()
siEx, err := setupConPtyStartupInfo(hPty)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("setup startup info: %w", err)
}
defer func() {
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
}()
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("create process as user with ConPty: %w", err)
}
defer closeProcessInfo(pi)
if err := windows.CloseHandle(inputRead); err != nil {
log.Debugf(errCloseInputRead, err)
}
if err := windows.CloseHandle(outputWrite); err != nil {
log.Debugf("close output write handle: %v", err)
}
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process)
}
// createConPtyPipes creates input/output pipes for ConPty.
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
}
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
log.Debugf(errCloseInputRead, closeErr)
}
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
log.Debugf("close input write handle: %v", closeErr)
}
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
}
return inputRead, inputWrite, outputRead, outputWrite, nil
}
// createConPty creates a Windows ConPty with the specified size and pipe handles.
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
size := windows.Coord{X: int16(width), Y: int16(height)}
var hPty windows.Handle
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
}
return hPty, nil
}
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
var siEx windows.StartupInfoEx
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
var attrListSize uintptr
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
if ret == 0 && attrListSize == 0 {
return nil, fmt.Errorf("get attribute list size")
}
attrListBytes := make([]byte, attrListSize)
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
ret, _, err := procInitializeProcThreadAttributeList.Call(
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
1,
0,
uintptr(unsafe.Pointer(&attrListSize)),
)
if ret == 0 {
return nil, fmt.Errorf("initialize attribute list: %w", err)
}
ret, _, err = procUpdateProcThreadAttribute.Call(
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
0,
procThreadAttributePseudoConsole,
uintptr(hPty),
unsafe.Sizeof(hPty),
0,
0,
)
if ret == 0 {
return nil, fmt.Errorf("update thread attribute: %w", err)
}
return &siEx, nil
}
// createConPtyProcess creates the actual process with ConPty.
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
var pi windows.ProcessInformation
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
if err != nil {
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
}
envPtr, err := convertEnvironmentToUTF16(userEnv)
if err != nil {
return nil, err
}
var workingDirPtr *uint16
if workingDir != "" {
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
if err != nil {
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
}
}
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
siEx.StartupInfo.StdInput = windows.Handle(0)
siEx.StartupInfo.StdOutput = windows.Handle(0)
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
if userToken != windows.InvalidHandle {
err = windows.CreateProcessAsUser(
windows.Token(userToken),
nil,
commandLinePtr,
nil,
nil,
true,
creationFlags,
envPtr,
workingDirPtr,
&siEx.StartupInfo,
&pi,
)
} else {
err = windows.CreateProcess(
nil,
commandLinePtr,
nil,
nil,
true,
creationFlags,
envPtr,
workingDirPtr,
&siEx.StartupInfo,
&pi,
)
}
if err != nil {
return nil, fmt.Errorf("create process: %w", err)
}
return &pi, nil
}
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
if len(userEnv) == 0 {
// Return nil pointer for empty environment - Windows API will inherit parent environment
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
}
var envUTF16 []uint16
for _, envVar := range userEnv {
if envVar != "" {
utf16Str, err := windows.UTF16FromString(envVar)
if err != nil {
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
continue
}
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
envUTF16 = append(envUTF16, 0)
}
}
envUTF16 = append(envUTF16, 0)
if len(envUTF16) > 0 {
return &envUTF16[0], nil
}
// Return nil pointer when no valid environment variables found
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
}
// duplicateToPrimaryToken converts an impersonation token to a primary token.
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
var primaryToken windows.Handle
if err := windows.DuplicateTokenEx(
windows.Token(token),
windows.TOKEN_ALL_ACCESS,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
(*windows.Token)(&primaryToken),
); err != nil {
return 0, fmt.Errorf("duplicate token: %w", err)
}
return primaryToken, nil
}
// SessionExiter provides the Exit method for reporting process exit status.
type SessionExiter interface {
Exit(code int) error
}
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error {
if err := ctx.Err(); err != nil {
return err
}
var wg sync.WaitGroup
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
processErr := waitForProcess(ctx, process)
if processErr != nil {
return processErr
}
var exitCode uint32
if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
log.Debugf("get exit code: %v", err)
} else {
if err := session.Exit(int(exitCode)); err != nil {
log.Debugf("report exit code: %v", err)
}
}
// Clean up in the original order after process completes
if err := reader.Close(); err != nil {
log.Debugf("close reader: %v", err)
}
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
if ret == 0 {
log.Debugf("close ConPty handle: %v", err)
}
wg.Wait()
if err := windows.CloseHandle(outputRead); err != nil {
log.Debugf("close output read handle: %v", err)
}
return nil
}
// startIOBridging starts the I/O bridging goroutines.
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
wg.Add(2)
// Input: reader (SSH session) -> inputWrite (ConPty)
go func() {
defer wg.Done()
defer func() {
if err := windows.CloseHandle(inputWrite); err != nil {
log.Debugf("close input write handle in goroutine: %v", err)
}
}()
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
log.Debugf("input copy ended with error: %v", err)
}
}()
// Output: outputRead (ConPty) -> writer (SSH session)
go func() {
defer wg.Done()
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
log.Debugf("output copy ended with error: %v", err)
}
}()
}
// waitForProcess waits for process completion with context cancellation.
func waitForProcess(ctx context.Context, process windows.Handle) error {
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
return fmt.Errorf("wait for process %d: %w", process, err)
}
return nil
}
// buildShellArgs builds shell arguments for ConPty execution.
func buildShellArgs(shell, command string) []string {
if command != "" {
return []string{shell, PowerShellCommandFlag, command}
}
return []string{shell}
}
// buildCommandLine builds a Windows command line from arguments using proper escaping.
func buildCommandLine(args []string) string {
if len(args) == 0 {
return ""
}
var result strings.Builder
for i, arg := range args {
if i > 0 {
result.WriteString(" ")
}
result.WriteString(syscall.EscapeArg(arg))
}
return result.String()
}
// closeHandles closes multiple Windows handles.
func closeHandles(handles ...windows.Handle) {
for _, handle := range handles {
if handle != windows.InvalidHandle {
if err := windows.CloseHandle(handle); err != nil {
log.Debugf("close handle: %v", err)
}
}
}
}
// closeProcessInfo closes process and thread handles.
func closeProcessInfo(pi *windows.ProcessInformation) {
if pi != nil {
if err := windows.CloseHandle(pi.Process); err != nil {
log.Debugf("close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Debugf("close thread handle: %v", err)
}
}
}
// windowsHandleReader wraps a Windows handle for reading.
type windowsHandleReader struct {
handle windows.Handle
}
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
var bytesRead uint32
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
return 0, err
}
return int(bytesRead), nil
}
func (r *windowsHandleReader) Close() error {
return windows.CloseHandle(r.handle)
}
// windowsHandleWriter wraps a Windows handle for writing.
type windowsHandleWriter struct {
handle windows.Handle
}
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
var bytesWritten uint32
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
return 0, err
}
return int(bytesWritten), nil
}
func (w *windowsHandleWriter) Close() error {
return windows.CloseHandle(w.handle)
}

View File

@@ -0,0 +1,290 @@
//go:build windows
package winpty
import (
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows"
)
func TestBuildShellArgs(t *testing.T) {
tests := []struct {
name string
shell string
command string
expected []string
}{
{
name: "Shell with command",
shell: "powershell.exe",
command: "Get-Process",
expected: []string{"powershell.exe", "-Command", "Get-Process"},
},
{
name: "CMD with command",
shell: "cmd.exe",
command: "dir",
expected: []string{"cmd.exe", "-Command", "dir"},
},
{
name: "Shell interactive",
shell: "powershell.exe",
command: "",
expected: []string{"powershell.exe"},
},
{
name: "CMD interactive",
shell: "cmd.exe",
command: "",
expected: []string{"cmd.exe"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildShellArgs(tt.shell, tt.command)
assert.Equal(t, tt.expected, result)
})
}
}
func TestBuildCommandLine(t *testing.T) {
tests := []struct {
name string
args []string
expected string
}{
{
name: "Simple args",
args: []string{"cmd.exe", "/c", "echo"},
expected: "cmd.exe /c echo",
},
{
name: "Args with spaces",
args: []string{"Program Files\\app.exe", "arg with spaces"},
expected: `"Program Files\app.exe" "arg with spaces"`,
},
{
name: "Args with quotes",
args: []string{"cmd.exe", "/c", `echo "hello world"`},
expected: `cmd.exe /c "echo \"hello world\""`,
},
{
name: "PowerShell calling PowerShell",
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
},
{
name: "Complex nested quotes",
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
},
{
name: "Path with spaces and args",
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
},
{
name: "Empty argument",
args: []string{"cmd.exe", "/c", "echo", ""},
expected: `cmd.exe /c echo ""`,
},
{
name: "Argument with backslashes",
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
expected: `robocopy C:\Source\ C:\Dest\ /E`,
},
{
name: "Empty args",
args: []string{},
expected: "",
},
{
name: "Single arg with space",
args: []string{"path with spaces"},
expected: `"path with spaces"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildCommandLine(tt.args)
assert.Equal(t, tt.expected, result)
})
}
}
func TestCreateConPtyPipes(t *testing.T) {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
require.NoError(t, err, "Should create ConPty pipes successfully")
// Verify all handles are valid
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
// Clean up handles
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
}
func TestCreateConPty(t *testing.T) {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
require.NoError(t, err, "Should create ConPty pipes successfully")
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
hPty, err := createConPty(80, 24, inputRead, outputWrite)
require.NoError(t, err, "Should create ConPty successfully")
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
// Clean up ConPty
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
}
func TestConvertEnvironmentToUTF16(t *testing.T) {
tests := []struct {
name string
userEnv []string
hasError bool
}{
{
name: "Valid environment variables",
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
hasError: false,
},
{
name: "Empty environment",
userEnv: []string{},
hasError: false,
},
{
name: "Environment with empty strings",
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := convertEnvironmentToUTF16(tt.userEnv)
if tt.hasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if len(tt.userEnv) == 0 {
assert.Nil(t, result, "Empty environment should return nil")
} else {
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
}
}
})
}
}
func TestDuplicateToPrimaryToken(t *testing.T) {
if testing.Short() {
t.Skip("Skipping token tests in short mode")
}
// Get current process token for testing
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
require.NoError(t, err, "Should open current process token")
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
t.Logf("Failed to close token: %v", err)
}
}()
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
require.NoError(t, err, "Should duplicate token to primary")
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
// Clean up
err = windows.CloseHandle(primaryToken)
assert.NoError(t, err, "Should close primary token")
}
func TestWindowsHandleReader(t *testing.T) {
// Create a pipe for testing
var readHandle, writeHandle windows.Handle
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
require.NoError(t, err, "Should create pipe for testing")
defer closeHandles(readHandle, writeHandle)
// Write test data
testData := []byte("Hello, Windows Handle Reader!")
var bytesWritten uint32
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
require.NoError(t, err, "Should write test data")
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
// Close write handle to signal EOF
if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err)
}
writeHandle = windows.InvalidHandle
// Test reading
reader := &windowsHandleReader{handle: readHandle}
buffer := make([]byte, len(testData))
n, err := reader.Read(buffer)
require.NoError(t, err, "Should read from handle")
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
assert.Equal(t, testData, buffer, "Should read expected data")
}
func TestWindowsHandleWriter(t *testing.T) {
// Create a pipe for testing
var readHandle, writeHandle windows.Handle
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
require.NoError(t, err, "Should create pipe for testing")
defer closeHandles(readHandle, writeHandle)
// Test writing
testData := []byte("Hello, Windows Handle Writer!")
writer := &windowsHandleWriter{handle: writeHandle}
n, err := writer.Write(testData)
require.NoError(t, err, "Should write to handle")
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
// Close write handle
if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err)
}
// Verify data was written by reading it back
buffer := make([]byte, len(testData))
var bytesRead uint32
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
require.NoError(t, err, "Should read back written data")
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
assert.Equal(t, testData, buffer, "Should read back expected data")
}
// BenchmarkConPtyCreation benchmarks ConPty creation performance
func BenchmarkConPtyCreation(b *testing.B) {
for i := 0; i < b.N; i++ {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
if err != nil {
b.Fatal(err)
}
hPty, err := createConPty(80, 24, inputRead, outputWrite)
if err != nil {
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
b.Fatal(err)
}
// Clean up
if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 {
log.Debugf("ClosePseudoConsole failed: %v", err)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
}
}